Skip to content

Commit de3a957

Browse files
authored
fix: hoist validation and transformation to top of call chain. before, set_docs would be called on the un-transformed sections (#221)
1 parent 9bb40bc commit de3a957

File tree

3 files changed

+134
-94
lines changed

3 files changed

+134
-94
lines changed

lib/spark/dsl/entity.ex

Lines changed: 9 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -249,22 +249,20 @@ defmodule Spark.Dsl.Entity do
249249

250250
@doc false
251251
def build(
252-
entity,
252+
%{
253+
target: target,
254+
schema: schema,
255+
auto_set_fields: auto_set_fields,
256+
transform: transform,
257+
identifier: identifier,
258+
singleton_entity_keys: singleton_entity_keys,
259+
entities: nested_entity_definitions
260+
},
253261
opts,
254262
nested_entities,
255263
anno,
256264
opts_anno
257265
) do
258-
%{
259-
target: target,
260-
schema: schema,
261-
auto_set_fields: auto_set_fields,
262-
transform: transform,
263-
identifier: identifier,
264-
singleton_entity_keys: singleton_entity_keys,
265-
entities: nested_entity_definitions
266-
} = validate_and_transform(entity, [], nil)
267-
268266
with {:ok, opts, more_nested_entities} <-
269267
fetch_single_argument_entities_from_opts(
270268
opts,
@@ -458,56 +456,4 @@ defmodule Spark.Dsl.Entity do
458456

459457
nil
460458
end
461-
462-
@doc """
463-
Validates and transforms an entity structure, ensuring nested entities are properly formatted.
464-
465-
This function recursively processes a DSL entity and its nested entities, converting
466-
single entity values to lists where needed and validating the structure.
467-
468-
## Parameters
469-
470-
- `entity` - The entity to validate and transform
471-
- `path` - The current path in the DSL structure (for error reporting)
472-
- `module` - The module context (for error reporting)
473-
474-
## Returns
475-
476-
Returns the transformed entity with normalized nested entity structures.
477-
"""
478-
def validate_and_transform(entity, path \\ [], module \\ nil)
479-
480-
def validate_and_transform(%Spark.Dsl.Entity{} = entity, path, module) do
481-
# Include the entity's name in the path when processing nested entities
482-
nested_path = if entity.name, do: path ++ [entity.name], else: path
483-
484-
entities =
485-
entity.entities
486-
|> List.wrap()
487-
|> Enum.map(fn
488-
{key, %Spark.Dsl.Entity{} = value} ->
489-
{key, [validate_and_transform(value, nested_path ++ [key], module)]}
490-
491-
{key, values} when is_list(values) ->
492-
# Already a list, keep as is
493-
{key, Enum.map(values, &validate_and_transform(&1, nested_path ++ [key], module))}
494-
495-
{key, value} ->
496-
# Non-entity, non-list value - this is invalid
497-
raise Spark.Error.DslError,
498-
module: module,
499-
path: nested_path ++ [key],
500-
message:
501-
"nested entity '#{key}' must be an entity or list of entities, got: #{inspect(value)}"
502-
end)
503-
504-
%{entity | entities: entities}
505-
end
506-
507-
def validate_and_transform(_, path, module) do
508-
raise Spark.Error.DslError,
509-
module: module,
510-
path: path,
511-
message: "Invalid entity structure"
512-
end
513459
end

lib/spark/dsl/extension.ex

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,26 @@ defmodule Spark.Dsl.Extension do
381381
end
382382
end
383383

384+
def set_docs(items) when is_list(items) do
385+
Enum.map(items, &set_docs/1)
386+
end
387+
388+
def set_docs(%Spark.Dsl.Entity{} = entity) do
389+
entity
390+
|> Map.put(:docs, Spark.Dsl.Extension.doc_entity(entity))
391+
|> Map.put(
392+
:entities,
393+
Enum.map(entity.entities || [], fn {key, value} -> {key, set_docs(value)} end)
394+
)
395+
end
396+
397+
def set_docs(%Spark.Dsl.Section{} = section) do
398+
section
399+
|> Map.put(:entities, set_docs(section.entities))
400+
|> Map.put(:sections, set_docs(section.sections))
401+
|> Map.put(:docs, Spark.Dsl.Extension.doc_section(section))
402+
end
403+
384404
@doc false
385405
defmacro __using__(opts) do
386406
quote bind_quoted: [
@@ -397,19 +417,23 @@ defmodule Spark.Dsl.Extension do
397417
alias Spark.Dsl.Extension
398418
module_prefix = module_prefix || __MODULE__
399419

400-
@behaviour Extension
401-
Extension.build(__MODULE__, module_prefix, sections, dsl_patches)
402420
@_sections sections
421+
|> Enum.map(&Extension.validate_and_transform_section(&1, __MODULE__))
422+
|> Extension.set_docs()
423+
@_dsl_patches Extension.validate_and_transform_dsl_patches(dsl_patches, __MODULE__)
403424
@_transformers transformers
404425
@_verifiers verifiers
405426
@_persisters persisters
406-
@_dsl_patches dsl_patches
407427
@_imports imports
408428
@_add_extensions add_extensions
429+
430+
@behaviour Extension
431+
Extension.build(__MODULE__, module_prefix, @_sections, @_dsl_patches)
432+
409433
@after_verify Spark.Dsl.Extension
410434

411435
@doc false
412-
def sections, do: set_docs(@_sections)
436+
def sections, do: @_sections
413437
@doc false
414438
def verifiers, do: [Spark.Dsl.Verifiers.VerifyEntityUniqueness | @_verifiers]
415439
@doc false
@@ -424,26 +448,6 @@ defmodule Spark.Dsl.Extension do
424448
def dsl_patches, do: @_dsl_patches
425449
@doc false
426450
def add_extensions, do: @_add_extensions
427-
428-
defp set_docs(items) when is_list(items) do
429-
Enum.map(items, &set_docs/1)
430-
end
431-
432-
defp set_docs(%Spark.Dsl.Entity{} = entity) do
433-
entity
434-
|> Map.put(:docs, Spark.Dsl.Extension.doc_entity(entity))
435-
|> Map.put(
436-
:entities,
437-
Enum.map(entity.entities || [], fn {key, value} -> {key, set_docs(value)} end)
438-
)
439-
end
440-
441-
defp set_docs(%Spark.Dsl.Section{} = section) do
442-
section
443-
|> Map.put(:entities, set_docs(section.entities))
444-
|> Map.put(:sections, set_docs(section.sections))
445-
|> Map.put(:docs, Spark.Dsl.Extension.doc_section(section))
446-
end
447451
end
448452
end
449453

@@ -834,7 +838,7 @@ defmodule Spark.Dsl.Extension do
834838
%Spark.Dsl.Patch.AddEntity{entity: entity} = dsl_patch,
835839
module
836840
) do
837-
entity = Spark.Dsl.Entity.validate_and_transform(entity, [], module)
841+
entity = validate_and_transform_entity(entity, [], module)
838842
%{dsl_patch | entity: entity}
839843
end
840844

@@ -855,8 +859,6 @@ defmodule Spark.Dsl.Extension do
855859
] do
856860
alias Spark.Dsl.Extension
857861

858-
dsl_patches = Extension.validate_and_transform_dsl_patches(dsl_patches, __MODULE__)
859-
860862
{:ok, agent} = Agent.start_link(fn -> [] end)
861863
agent_and_pid = {agent, self()}
862864

@@ -904,11 +906,63 @@ defmodule Spark.Dsl.Extension do
904906
) do
905907
%{
906908
section
907-
| entities:
908-
Enum.map(entities, &Spark.Dsl.Entity.validate_and_transform(&1, [section.name], module))
909+
| entities: Enum.map(entities, &validate_and_transform_entity(&1, [section.name], module))
909910
}
910911
end
911912

913+
@doc """
914+
Validates and transforms an entity structure, ensuring nested entities are properly formatted.
915+
916+
This function recursively processes a DSL entity and its nested entities, converting
917+
single entity values to lists where needed and validating the structure.
918+
919+
## Parameters
920+
921+
- `entity` - The entity to validate and transform
922+
- `path` - The current path in the DSL structure (for error reporting)
923+
- `module` - The module context (for error reporting)
924+
925+
## Returns
926+
927+
Returns the transformed entity with normalized nested entity structures.
928+
"""
929+
def validate_and_transform_entity(entity, path \\ [], module \\ nil)
930+
931+
def validate_and_transform_entity(%Spark.Dsl.Entity{} = entity, path, module) do
932+
# Include the entity's name in the path when processing nested entities
933+
nested_path = if entity.name, do: path ++ [entity.name], else: path
934+
935+
entities =
936+
entity.entities
937+
|> List.wrap()
938+
|> Enum.map(fn
939+
{key, %Spark.Dsl.Entity{} = value} ->
940+
{key, [validate_and_transform_entity(value, nested_path ++ [key], module)]}
941+
942+
{key, values} when is_list(values) ->
943+
# Already a list, keep as is
944+
{key,
945+
Enum.map(values, &validate_and_transform_entity(&1, nested_path ++ [key], module))}
946+
947+
{key, value} ->
948+
# Non-entity, non-list value - this is invalid
949+
raise Spark.Error.DslError,
950+
module: module,
951+
path: nested_path ++ [key],
952+
message:
953+
"nested entity '#{key}' must be an entity or list of entities, got: #{inspect(value)}"
954+
end)
955+
956+
%{entity | entities: entities}
957+
end
958+
959+
def validate_and_transform_entity(_, path, module) do
960+
raise Spark.Error.DslError,
961+
module: module,
962+
path: path,
963+
message: "Invalid entity structure"
964+
end
965+
912966
@doc false
913967
defmacro build_section(
914968
agent,
@@ -927,8 +981,6 @@ defmodule Spark.Dsl.Extension do
927981
generated: true do
928982
alias Spark.Dsl
929983

930-
section = Dsl.Extension.validate_and_transform_section(section, __MODULE__)
931-
932984
{section_modules, entity_modules, opts_module} =
933985
Dsl.Extension.do_build_section(
934986
agent,

test/dsl_validation_test.exs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,47 @@ defmodule Spark.DslValidationTest do
8484
use Spark.Dsl, default_extensions: [extensions: Dsl]
8585
end
8686
end
87+
88+
test "dsl patches with nested entity definitions are properly transformed" do
89+
defmodule DslPatchWithNestedEntities do
90+
@moduledoc false
91+
92+
defmodule PatcherExtension do
93+
@moduledoc false
94+
95+
@nested_entity %Spark.Dsl.Entity{
96+
name: :nested_item,
97+
target: Spark.Test.Step,
98+
schema: []
99+
}
100+
101+
@another_nested %Spark.Dsl.Entity{
102+
name: :another_nested_item,
103+
target: Spark.Test.Step,
104+
schema: []
105+
}
106+
107+
@patched_entity %Spark.Dsl.Entity{
108+
name: :patched_entity,
109+
target: Spark.Test.Step,
110+
schema: [],
111+
# Test both single entity and list formats in nested entities
112+
entities: [
113+
single: @nested_entity,
114+
multiple: [@another_nested]
115+
]
116+
}
117+
118+
@patch %Spark.Dsl.Patch.AddEntity{
119+
section_path: [:steps],
120+
entity: @patched_entity
121+
}
122+
123+
use Spark.Dsl.Extension, dsl_patches: [@patch]
124+
end
125+
126+
use Spark.Dsl, default_extensions: [extensions: [Spark.Test.Extension, PatcherExtension]]
127+
end
128+
end
87129
end
88130
end

0 commit comments

Comments
 (0)