diff --git a/deps/deps.go b/deps/deps.go index e5dc339a4..da0db6e9d 100644 --- a/deps/deps.go +++ b/deps/deps.go @@ -2,6 +2,7 @@ package deps import ( + "bytes" "context" "crypto/rand" "database/sql" @@ -429,6 +430,11 @@ func LoadConfigWithUpgrades(text string, curioConfigWithDefaults *config.CurioCo return meta, err } func GetConfig(ctx context.Context, layers []string, db *harmonydb.DB) (*config.CurioConfig, error) { + err := updateBaseLayer(ctx, db) + if err != nil { + return nil, err + } + curioConfig := config.DefaultCurioConfig() have := []string{} layers = append([]string{"base"}, layers...) // Always stack on top of "base" layer @@ -462,6 +468,131 @@ func GetConfig(ctx context.Context, layers []string, db *harmonydb.DB) (*config. return curioConfig, nil } +func updateBaseLayer(ctx context.Context, db *harmonydb.DB) error { + _, err := db.BeginTransaction(ctx, func(tx *harmonydb.Tx) (commit bool, err error) { + // Get existing base from DB + text := "" + err = tx.QueryRow(`SELECT config FROM harmony_config WHERE title=$1`, "base").Scan(&text) + if err != nil { + if strings.Contains(err.Error(), sql.ErrNoRows.Error()) { + return false, fmt.Errorf("missing layer 'base' ") + } + return false, fmt.Errorf("could not read layer 'base': %w", err) + } + + // Load the existing configuration + cfg := config.DefaultCurioConfig() + metadata, err := LoadConfigWithUpgrades(text, cfg) + if err != nil { + return false, fmt.Errorf("could not read base layer, bad toml %s: %w", text, err) + } + + // Capture unknown fields + keys := removeUnknownEntries(metadata.Keys(), metadata.Undecoded()) + unrecognizedFields := extractUnknownFields(keys, text) + + // Convert the updated config back to TOML string + cb, err := config.ConfigUpdate(cfg, config.DefaultCurioConfig(), config.Commented(true), config.DefaultKeepUncommented(), config.NoEnv()) + if err != nil { + return false, xerrors.Errorf("cannot update base config: %w", err) + } + + // Merge unknown fields back into the updated config + finalConfig, err := mergeUnknownFields(string(cb), unrecognizedFields) + if err != nil { + return false, xerrors.Errorf("cannot merge unknown fields: %w", err) + } + + // Check if we need to update the DB + if text == finalConfig { + return false, nil + } + + // Save the updated base with merged comments + _, err = tx.Exec("UPDATE harmony_config SET config=$1 WHERE title='base'", finalConfig) + if err != nil { + return false, xerrors.Errorf("cannot update base config: %w", err) + } + + return true, nil + }, harmonydb.OptionRetry()) + + if err != nil { + return err + } + + return nil +} + +func extractUnknownFields(knownKeys []toml.Key, originalConfig string) map[string]interface{} { + // Parse the original config into a raw map + var rawConfig map[string]interface{} + err := toml.Unmarshal([]byte(originalConfig), &rawConfig) + if err != nil { + log.Warnw("Failed to parse original config for unknown fields", "error", err) + return nil + } + + // Collect all recognized keys + recognizedKeys := map[string]struct{}{} + for _, key := range knownKeys { + recognizedKeys[strings.Join(key, ".")] = struct{}{} + } + + // Identify unrecognized fields + unrecognizedFields := map[string]interface{}{} + for key, value := range rawConfig { + if _, recognized := recognizedKeys[key]; !recognized { + unrecognizedFields[key] = value + } + } + return unrecognizedFields +} + +func removeUnknownEntries(array1, array2 []toml.Key) []toml.Key { + // Create a set from array2 for fast lookup + toRemove := make(map[string]struct{}, len(array2)) + for _, key := range array2 { + toRemove[key.String()] = struct{}{} + } + + // Filter array1, keeping only elements not in toRemove + var result []toml.Key + for _, key := range array1 { + if _, exists := toRemove[key.String()]; !exists { + result = append(result, key) + } + } + + return result +} + +func mergeUnknownFields(updatedConfig string, unrecognizedFields map[string]interface{}) (string, error) { + // Parse the updated config into a raw map + var updatedConfigMap map[string]interface{} + err := toml.Unmarshal([]byte(updatedConfig), &updatedConfigMap) + if err != nil { + return "", fmt.Errorf("failed to parse updated config: %w", err) + } + + // Merge unrecognized fields + for key, value := range unrecognizedFields { + if _, exists := updatedConfigMap[key]; !exists { + updatedConfigMap[key] = value + } + } + + // Convert back into TOML + b := new(bytes.Buffer) + encoder := toml.NewEncoder(b) + err = encoder.Encode(updatedConfigMap) + if err != nil { + return "", fmt.Errorf("failed to marshal final config: %w", err) + } + + return b.String(), nil +} + func GetDefaultConfig(comment bool) (string, error) { c := config.DefaultCurioConfig() cb, err := config.ConfigUpdate(c, nil, config.Commented(comment), config.DefaultKeepUncommented(), config.NoEnv()) diff --git a/deps/deps_test.go b/deps/deps_test.go new file mode 100644 index 000000000..5cec4b5a4 --- /dev/null +++ b/deps/deps_test.go @@ -0,0 +1,124 @@ +package deps + +import ( + "bytes" + "testing" + + "github.com/BurntSushi/toml" +) + +type ExampleConfig struct { + Subsystems struct { + EnableWindowPost bool + WindowPostMaxTasks int + } + Fees struct { + DefaultMaxFee string + } +} + +// An original TOML configuration that has both recognized and unknown fields. +const originalTOML = ` +[Subsystems] + EnableWindowPost = true + WindowPostMaxTasks = 5 + +[Fees] + DefaultMaxFee = "0.07 FIL" + +[UnknownSection] + SomeUnknownKey = "whatever" + AnotherField = 123 + +[AnotherUnknownSection.Nested] + NestedValue = "I am nested" +` + +func TestExtractAndMergeUnknownFields(t *testing.T) { + //---------------------------------------------------------------------- + // Step 1: Decode original TOML into recognized struct & collect MetaData + //---------------------------------------------------------------------- + var recognized ExampleConfig + meta, err := toml.Decode(originalTOML, &recognized) + if err != nil { + t.Fatalf("failed to decode recognized fields: %v", err) + } + + keys := removeUnknownEntries(meta.Keys(), meta.Undecoded()) + + //---------------------------------------------------------------------- + // Step 2: Extract the unknown fields using extractUnknownFields + //---------------------------------------------------------------------- + unknownFields := extractUnknownFields(keys, originalTOML) + if len(unknownFields) == 0 { + t.Errorf("expected unknown fields, got none") + } + + //---------------------------------------------------------------------- + // Step 3: Update recognized fields in the struct + //---------------------------------------------------------------------- + recognized.Subsystems.EnableWindowPost = false // flip the boolean + recognized.Subsystems.WindowPostMaxTasks = 10 // change from 5 to 10 + recognized.Fees.DefaultMaxFee = "0.08 FIL" // update the fee + + //---------------------------------------------------------------------- + // Step 4: Re-encode recognized fields back to TOML + //---------------------------------------------------------------------- + var buf bytes.Buffer + if err := toml.NewEncoder(&buf).Encode(recognized); err != nil { + t.Fatalf("failed to marshal updated recognized config: %v", err) + } + updatedConfig := buf.String() + + //---------------------------------------------------------------------- + // Step 5: Merge unknown fields back + //---------------------------------------------------------------------- + finalConfig, err := mergeUnknownFields(updatedConfig, unknownFields) + if err != nil { + t.Fatalf("failed to merge unknown fields: %v", err) + } + + //---------------------------------------------------------------------- + // Assertions: Check recognized fields have changed & unknown remain + //---------------------------------------------------------------------- + // 5a. Parse final config into a map to check contents + var finalMap map[string]interface{} + if err := toml.Unmarshal([]byte(finalConfig), &finalMap); err != nil { + t.Fatalf("failed to parse final config: %v\nFinal Config:\n%s", err, finalConfig) + } + + // 5b. Check recognized fields updated + subsystems, ok := finalMap["Subsystems"].(map[string]interface{}) + if !ok { + t.Fatalf("expected 'Subsystems' in final config") + } + + if enable, _ := subsystems["EnableWindowPost"].(bool); enable { + t.Errorf("expected Subsystems.EnableWindowPost = false, got true") + } + if tasks, _ := subsystems["WindowPostMaxTasks"].(int64); tasks != 10 { + t.Errorf("expected Subsystems.WindowPostMaxTasks = 10, got %d", tasks) + } + + fees, ok := finalMap["Fees"].(map[string]interface{}) + if !ok { + t.Fatalf("expected 'Fees' in final config") + } + if defaultFee, _ := fees["DefaultMaxFee"].(string); defaultFee != "0.08 FIL" { + t.Errorf("expected Fees.DefaultMaxFee = '0.08 FIL', got '%s'", defaultFee) + } + + // 5c. Check unknown fields remain + if _, exists := finalMap["UnknownSection"]; !exists { + t.Errorf("expected UnknownSection to remain in final config, but not found") + } + if anotherUnknown, exists := finalMap["AnotherUnknownSection"]; !exists { + t.Errorf("expected AnotherUnknownSection to remain in final config, but not found") + } else { + // Inside nested + nested := anotherUnknown.(map[string]interface{})["Nested"].(map[string]interface{}) + if val, ok := nested["NestedValue"].(string); !ok || val != "I am nested" { + t.Errorf("expected AnotherUnknownSection.Nested.NestedValue = 'I am nested', got '%v'", val) + } + } +}