Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 46 additions & 38 deletions lib/starlark_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ type SymmetricValues struct {
starlark.Tuple
}

func (s SymmetricValues) Index(i int) SymmetricValue { return s.Tuple.Index(i).(SymmetricValue) }
func (s SymmetricValues) Type() string { return "symmetric_values" }
func (s SymmetricValues) Index(i int) *SymmetricValue { return s.Tuple.Index(i).(*SymmetricValue) }
func (s SymmetricValues) Type() string { return "symmetric_values" }

func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
prefix := ""
Expand All @@ -123,7 +123,7 @@ func MakeSymmetricValues(thread *starlark.Thread, fn *starlark.Builtin, args sta
return values, nil
}

func NewSymmetricValues(values []SymmetricValue) *SymmetricValues {
func NewSymmetricValues(values []*SymmetricValue) *SymmetricValues {
sv := &SymmetricValues{}
for _, v := range values {
sv.Tuple = append(sv.Tuple, v)
Expand Down Expand Up @@ -1192,8 +1192,8 @@ var _ starlark.Value = (*Bag)(nil)
// There is a limit in the API, this is required to get 'in' keyword to work
var _ starlark.Mapping = (*Bag)(nil)

func NewModelValue(prefix string, i int64) ModelValue {
return ModelValue{
func NewModelValue(prefix string, i int64) *ModelValue {
return &ModelValue{
prefix: prefix,
id: i,
}
Expand All @@ -1206,16 +1206,20 @@ type ModelValue struct {
id int64
}

func (m ModelValue) GetPrefix() string {
func (m *ModelValue) GetPrefix() string {
return m.prefix
}

func (m ModelValue) GetId() int64 {
func (m *ModelValue) GetId() int64 {
return m.id
}

func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
y := y_.(ModelValue)
func (m *ModelValue) SetId(id int64) {
m.id = id
}

func (m *ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
y := y_.(*ModelValue)
switch op {
case syntax.EQL:
return modelValueEqual(m, y), nil
Expand All @@ -1226,57 +1230,57 @@ func (m ModelValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth in
}
}

func modelValueEqual(s ModelValue, y ModelValue) bool {
func modelValueEqual(s *ModelValue, y *ModelValue) bool {
return s.prefix == y.prefix && s.id == y.id
}

func (m ModelValue) String() string {
func (m *ModelValue) String() string {
return fmt.Sprintf("%s%d", m.prefix, m.id)
}

func (m ModelValue) FullString() string {
func (m *ModelValue) FullString() string {
return fmt.Sprintf("%s%d", m.prefix, m.id)
}

func (m ModelValue) ShortString() string {
func (m *ModelValue) ShortString() string {
return fmt.Sprintf("%s%d", m.prefix, m.id)
}

func (m ModelValue) MarshalJSON() ([]byte, error) {
func (m *ModelValue) MarshalJSON() ([]byte, error) {
return json.Marshal(m.FullString())
}

func (m ModelValue) Type() string {
func (m *ModelValue) Type() string {
return ModelValueType
}

func (m ModelValue) Freeze() {}
func (m *ModelValue) Freeze() {}

func (m ModelValue) Truth() starlark.Bool {
func (m *ModelValue) Truth() starlark.Bool {
return true
}

func (m ModelValue) Hash() (uint32, error) {
func (m *ModelValue) Hash() (uint32, error) {
return starlark.String(m.FullString()).Hash()
}

var _ starlark.Value = ModelValue{}
var _ starlark.Comparable = ModelValue{}
var _ starlark.Value = (*ModelValue)(nil)
var _ starlark.Comparable = (*ModelValue)(nil)

// NewSymmetricValue creates a new symmetric value with the default Nominal kind.
// This maintains backward compatibility with existing code.
func NewSymmetricValue(prefix string, i int64) SymmetricValue {
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal}
func NewSymmetricValue(prefix string, i int64) *SymmetricValue {
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindNominal}
}

// NewSymmetricValueWithKind creates a new symmetric value with the specified kind.
func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) SymmetricValue {
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind}
func NewSymmetricValueWithKind(prefix string, i int64, kind SymmetryKind) *SymmetricValue {
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: kind}
}

// NewRotationalSymmetricValue creates a new symmetric value for rotational domains with the limit set.
func NewRotationalSymmetricValue(prefix string, i int64, limit int) SymmetricValue {
return SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit}
func NewRotationalSymmetricValue(prefix string, i int64, limit int) *SymmetricValue {
return &SymmetricValue{ModelValue: ModelValue{prefix: prefix, id: i}, Kind: SymmetryKindRotational, Limit: limit}
}

const SymmetricValueType = "symmetric_value"
Expand All @@ -1287,12 +1291,12 @@ type SymmetricValue struct {
Limit int // Only used for rotational kind (domain size for mod arithmetic). 0 for other kinds.
}

func (s SymmetricValue) GetKind() SymmetryKind {
func (s *SymmetricValue) GetKind() SymmetryKind {
return s.Kind
}

func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
y := y_.(SymmetricValue)
func (s *SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, depth int) (bool, error) {
y := y_.(*SymmetricValue)

// Values from different domains cannot be compared
if s.prefix != y.prefix {
Expand All @@ -1301,9 +1305,9 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept

switch op {
case syntax.EQL:
return modelValueEqual(s.ModelValue, y.ModelValue), nil
return symmetricValueEqual(s, y), nil
case syntax.NEQ:
return !modelValueEqual(s.ModelValue, y.ModelValue), nil
return !symmetricValueEqual(s, y), nil
case syntax.LT, syntax.LE, syntax.GT, syntax.GE:
// Ordering only allowed for Ordinal and Interval kinds
if s.Kind == SymmetryKindNominal {
Expand All @@ -1326,18 +1330,22 @@ func (s SymmetricValue) CompareSameType(op syntax.Token, y_ starlark.Value, dept
return false, fmt.Errorf("%s %s %s not implemented", s.Type(), op, y.Type())
}

func (s SymmetricValue) Type() string {
func symmetricValueEqual(s *SymmetricValue, y *SymmetricValue) bool {
return s.prefix == y.prefix && s.id == y.id
}

func (s *SymmetricValue) Type() string {
return SymmetricValueType
}

var _ starlark.Value = SymmetricValue{}
var _ starlark.Comparable = SymmetricValue{}
var _ starlark.Value = (*SymmetricValue)(nil)
var _ starlark.Comparable = (*SymmetricValue)(nil)

func CompareStringer[E fmt.Stringer](a, b E) int {
return strings.Compare(a.String(), b.String())
}

func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
func (s *SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
if s.Kind == SymmetryKindRotational {
return s.binaryRotational(op, y, side)
}
Expand All @@ -1357,7 +1365,7 @@ func (s SymmetricValue) Binary(op syntax.Token, y starlark.Value, side starlark.
return NewSymmetricValueWithKind(s.prefix, s.id-int64(i), s.Kind), nil
}
// SymmetricValue - SymmetricValue -> int (signed difference)
if other, ok := y.(SymmetricValue); ok {
if other, ok := y.(*SymmetricValue); ok {
if s.prefix != other.prefix {
return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix)
}
Expand All @@ -1373,7 +1381,7 @@ func modPositive(a, m int64) int64 {
return ((a % m) + m) % m
}

func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
func (s *SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side starlark.Side) (starlark.Value, error) {
limit := int64(s.Limit)
if limit <= 0 {
return nil, fmt.Errorf("rotational value has invalid limit %d", s.Limit)
Expand All @@ -1390,7 +1398,7 @@ func (s SymmetricValue) binaryRotational(op syntax.Token, y starlark.Value, side
return NewRotationalSymmetricValue(s.prefix, modPositive(s.id-int64(i), limit), s.Limit), nil
}
// SymmetricValue - SymmetricValue -> plain int (mod limit)
if other, ok := y.(SymmetricValue); ok {
if other, ok := y.(*SymmetricValue); ok {
if s.prefix != other.prefix {
return nil, fmt.Errorf("cannot subtract values from different domains: %s vs %s", s.prefix, other.prefix)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/symmetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (d *SymmetryDomain) segments(thread *starlark.Thread, _ *starlark.Builtin,

// Helper to extract int64 ID from value
getID := func(v starlark.Value) (int64, bool) {
if sv, ok := v.(SymmetricValue); ok {
if sv, ok := v.(*SymmetricValue); ok {
if sv.prefix == d.Name {
return sv.id, true
}
Expand Down
2 changes: 1 addition & 1 deletion modelchecker/channel_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (cm *ChannelMessage) HashCode() string {
return fmt.Sprintf("%x", h.Sum(nil))
}

func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) *ChannelMessage {
func (cm *ChannelMessage) Clone(refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) *ChannelMessage {
frame, err := cm.frame.Clone(refs, permutations, alt)
PanicOnError(err)
params := CloneDict(cm.params, refs, permutations, alt)
Expand Down
58 changes: 43 additions & 15 deletions modelchecker/clone.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func deepCloneStarlarkValue(value starlark.Value, refs map[starlark.Value]starla
return deepCloneStarlarkValueWithPermutations(value, refs, nil, 0)
}

func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (starlark.Value, error) {
func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (starlark.Value, error) {
if refs == nil {
refs = make(map[starlark.Value]starlark.Value)
} else if reflect.TypeOf(value).Kind() != reflect.Ptr {
Expand All @@ -25,19 +25,37 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
// "builtin_function_or_method".
switch value.Type() {

case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "model_value", "Channel", "symmetry_domain":
case "NoneType", "int", "float", "bool", "string", "bytes", "function", "range", "struct", "symmetric_values", "Channel", "symmetry_domain":
//case starlark.Bool, starlark.String, starlark.Int:
// For simple values, just return a copy
// Also starlark struct is immutable
// symmetry_domain is stateless configuration (Name, Limit, DomainType)
return value, nil

case "model_value":
// ModelValue is now a pointer type - clone it as a reference
v := value.(*lib.ModelValue)
newVal := lib.NewModelValue(v.GetPrefix(), v.GetId())
refs[value] = newVal
return newVal, nil

case "symmetric_value":
// SymmetricValue is now a pointer type - clone it as a reference
v := value.(*lib.SymmetricValue)
if permutations != nil {
v := value.(lib.SymmetricValue)
// First try direct pointer lookup
if other, ok := permutations[v]; ok {
// Return the permuted value - also add to refs for consistency
refs[value] = other[alt]
return other[alt], nil
}
// Fall back to value-based lookup (for cases where SymmetricValues are created on-the-fly)
for k, other := range permutations {
if k.GetPrefix() == v.GetPrefix() && k.GetId() == v.GetId() {
refs[value] = other[alt]
return other[alt], nil
}
}
panic(fmt.Sprintf("symmetric_value %#v (Kind: %d) should be in %v and alt %v. Keys: %+v", v, v.Kind, permutations, alt, func() []string {
keys := make([]string, 0, len(permutations))
for k := range permutations {
Expand All @@ -46,7 +64,13 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
return keys
}()))
}
return value, nil
// Clone the symmetric value - now it's a pointer type
newVal := lib.NewSymmetricValueWithKind(v.GetPrefix(), v.GetId(), v.GetKind())
if v.Kind == lib.SymmetryKindRotational {
newVal = lib.NewRotationalSymmetricValue(v.GetPrefix(), v.GetId(), v.Limit)
}
refs[value] = newVal
return newVal, nil
case "list":
newVal := starlark.NewList(make([]starlark.Value, 0))
refs[value] = newVal
Expand Down Expand Up @@ -121,16 +145,20 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
newRight := s.Right

if permutations != nil {
// Map Left ID
leftSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Left, s.Domain.Kind)
if other, ok := permutations[leftSV]; ok {
newLeft = other[alt].GetId()
// Map Left ID - look up by pointer in permutations
for oldSV, newSVs := range permutations {
if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Left {
newLeft = newSVs[alt].GetId()
break
}
}

// Map Right ID
rightSV := lib.NewSymmetricValueWithKind(s.Domain.Name, s.Right, s.Domain.Kind)
if other, ok := permutations[rightSV]; ok {
newRight = other[alt].GetId()
// Map Right ID - look up by pointer in permutations
for oldSV, newSVs := range permutations {
if oldSV.GetPrefix() == s.Domain.Name && oldSV.GetId() == s.Right {
newRight = newSVs[alt].GetId()
break
}
}
}

Expand Down Expand Up @@ -201,7 +229,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
if err != nil {
return nil, err
}
newRoleId := newVal.(lib.SymmetricValue)
newRoleId := newVal.(*lib.SymmetricValue)
prefix = newRoleId.GetPrefix()
id = newRoleId.GetId()
}
Expand Down Expand Up @@ -236,7 +264,7 @@ func deepCloneStarlarkValueWithPermutations(value starlark.Value, refs map[starl
}

func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Value,
src map[lib.SymmetricValue][]lib.SymmetricValue, alt int) (*starlark.Dict, error) {
src map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) (*starlark.Dict, error) {

newDict := &starlark.Dict{}
refs[v] = newDict
Expand All @@ -255,7 +283,7 @@ func deepCloneStringDict(v *starlark.Dict, refs map[starlark.Value]starlark.Valu
return newDict, nil
}

func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[lib.SymmetricValue][]lib.SymmetricValue, alt int) ([]starlark.Value, error) {
func deepCloneIterableToList(iterable starlark.Iterable, refs map[starlark.Value]starlark.Value, permutations map[*lib.SymmetricValue][]*lib.SymmetricValue, alt int) ([]starlark.Value, error) {
var newList []starlark.Value
iter := iterable.Iterate()
defer iter.Done()
Expand Down
Loading