Skip to content

Commit ad7476b

Browse files
committed
Support for composite foreign keys for belongs_to
1 parent 652894c commit ad7476b

File tree

13 files changed

+626
-180
lines changed

13 files changed

+626
-180
lines changed

Diff for: Earthfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ integration-test-base:
6060
apk del .build-dependencies && rm -f msodbcsql*.sig mssql-tools*.apk
6161
ENV PATH="/opt/mssql-tools/bin:${PATH}"
6262

63-
GIT CLONE https://github.com/elixir-ecto/ecto_sql.git /src/ecto_sql
63+
GIT CLONE --branch composite_foreign_keys https://github.com/soundmonster/ecto_sql.git /src/ecto_sql
6464
WORKDIR /src/ecto_sql
6565
RUN mix deps.get
6666

Diff for: integration_test/cases/assoc.exs

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ defmodule Ecto.Integration.AssocTest do
1010
alias Ecto.Integration.PostUser
1111
alias Ecto.Integration.Comment
1212
alias Ecto.Integration.Permalink
13+
alias Ecto.Integration.CompositePk
1314

1415
test "has_many assoc" do
1516
p1 = TestRepo.insert!(%Post{title: "1"})
@@ -750,6 +751,12 @@ defmodule Ecto.Integration.AssocTest do
750751
assert Enum.all?(tree.post.comments, & &1.id)
751752
end
752753

754+
test "inserting struct with associations on composite keys" do
755+
# creates nested belongs_to
756+
%Post{composite: composite} = TestRepo.insert!(%Post{title: "1", composite: %CompositePk{a: 1, b: 2, name: "name"}})
757+
assert %CompositePk{a: 1, b: 2, name: "name"} = composite
758+
end
759+
753760
test "inserting struct with empty associations" do
754761
permalink = TestRepo.insert!(%Permalink{url: "root", post: nil})
755762
assert permalink.post == nil

Diff for: integration_test/cases/repo.exs

+18
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,24 @@ defmodule Ecto.Integration.RepoTest do
152152
assert TestRepo.all(PostUserCompositePk) == []
153153
end
154154

155+
@tag :composite_pk
156+
# TODO this needs a better name
157+
test "insert, update and delete with associated composite pk #2" do
158+
composite = TestRepo.insert!(%CompositePk{a: 1, b: 2, name: "name"})
159+
post = TestRepo.insert!(%Post{title: "post title", composite: composite})
160+
161+
assert post.composite_a == 1
162+
assert post.composite_b == 2
163+
assert TestRepo.get_by!(CompositePk, [a: 1, b: 2]) == composite
164+
165+
post = post |> Ecto.Changeset.change(composite: nil) |> TestRepo.update!
166+
assert is_nil(post.composite_a)
167+
assert is_nil(post.composite_b)
168+
169+
TestRepo.delete!(post)
170+
assert TestRepo.all(CompositePk) == [composite]
171+
end
172+
155173
@tag :invalid_prefix
156174
test "insert, update and delete with invalid prefix" do
157175
post = TestRepo.insert!(%Post{})

Diff for: integration_test/support/schemas.exs

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ defmodule Ecto.Integration.Post do
5454
has_one :update_permalink, Ecto.Integration.Permalink, foreign_key: :post_id, on_delete: :delete_all, on_replace: :update
5555
has_many :comments_authors, through: [:comments, :author]
5656
belongs_to :author, Ecto.Integration.User
57+
belongs_to :composite, Ecto.Integration.CompositePk,
58+
foreign_key: [:composite_a, :composite_b], references: [:a, :b], type: [:integer, :integer], on_replace: :nilify
5759
many_to_many :users, Ecto.Integration.User,
5860
join_through: "posts_users", on_delete: :delete_all, on_replace: :delete
5961
many_to_many :ordered_users, Ecto.Integration.User, join_through: "posts_users", preload_order: [desc: :name]
@@ -291,6 +293,7 @@ defmodule Ecto.Integration.CompositePk do
291293
field :a, :integer, primary_key: true
292294
field :b, :integer, primary_key: true
293295
field :name, :string
296+
has_many :posts, Ecto.Integration.Post, foreign_key: [:composite_a, :composite_b], references: [:a, :b]
294297
end
295298
def changeset(schema, params) do
296299
cast(schema, params, ~w(a b name)a)

Diff for: lib/ecto.ex

+9-4
Original file line numberDiff line numberDiff line change
@@ -510,10 +510,15 @@ defmodule Ecto do
510510
refl = %{owner_key: owner_key} = Ecto.Association.association_from_schema!(schema, assoc)
511511

512512
values =
513-
Enum.uniq for(struct <- structs,
514-
assert_struct!(schema, struct),
515-
key = Map.fetch!(struct, owner_key),
516-
do: key)
513+
structs
514+
|> Enum.filter(&assert_struct!(schema, &1))
515+
|> Enum.map(fn struct ->
516+
owner_key
517+
# TODO remove List.wrap once all assocs use lists
518+
|> List.wrap
519+
|> Enum.map(&Map.fetch!(struct, &1))
520+
end)
521+
|> Enum.uniq
517522

518523
case assocs do
519524
[] ->

Diff for: lib/ecto/association.ex

+99-35
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ defmodule Ecto.Association do
3535
required(:cardinality) => :one | :many,
3636
required(:relationship) => :parent | :child,
3737
required(:owner) => atom,
38-
required(:owner_key) => atom,
38+
required(:owner_key) => list(atom),
3939
required(:field) => atom,
4040
required(:unique) => boolean,
4141
optional(atom) => any}
@@ -71,7 +71,8 @@ defmodule Ecto.Association do
7171
7272
* `:owner` - the owner module of the association
7373
74-
* `:owner_key` - the key in the owner with the association value
74+
* `:owner_key` - the key in the owner with the association value, or a
75+
list of keys for composite keys
7576
7677
* `:relationship` - if the relationship to the specified schema is
7778
of a `:child` or a `:parent`
@@ -235,8 +236,15 @@ defmodule Ecto.Association do
235236
# for the final WHERE clause with values.
236237
{_, query, _, dest_out_key} = Enum.reduce(joins, {source, query, counter, source.out_key}, fn curr_rel, {prev_rel, query, counter, _} ->
237238
related_queryable = curr_rel.schema
238-
239-
next = join(query, :inner, [{src, counter}], dest in ^related_queryable, on: field(src, ^prev_rel.out_key) == field(dest, ^curr_rel.in_key))
239+
# TODO remove this once all relations store keys in lists
240+
in_keys = List.wrap(curr_rel.in_key)
241+
out_keys = List.wrap(prev_rel.out_key)
242+
next = query
243+
# join on the first field of the foreign key
244+
|> join(:inner, [{src, counter}], dest in ^related_queryable, on: field(src, ^hd(out_keys)) == field(dest, ^hd(in_keys)))
245+
# add the rest of the foreign key fields, if any
246+
|> composite_joins_query(counter, counter + 1, tl(out_keys), tl(in_keys))
247+
# consider where clauses on assocs
240248
|> combine_joins_query(curr_rel.where, counter + 1)
241249

242250
{curr_rel, next, counter + 1, curr_rel.out_key}
@@ -320,6 +328,16 @@ defmodule Ecto.Association do
320328
end)
321329
end
322330

331+
# TODO docs
332+
def composite_joins_query(query, _binding_src, _binding_dst, [], []) do
333+
query
334+
end
335+
def composite_joins_query(query, binding_src, binding_dst, [src_key | src_keys], [dst_key | dst_keys]) do
336+
# TODO
337+
[query, binding_src, binding_dst, [src_key | src_keys], [dst_key | dst_keys]] |> IO.inspect(label: :composite_joins_query)
338+
query
339+
end
340+
323341
@doc """
324342
Add the default assoc query where clauses to a join.
325343
@@ -335,6 +353,16 @@ defmodule Ecto.Association do
335353
%{query | joins: joins ++ [%{join_expr | on: %{join_on | expr: expr, params: params}}]}
336354
end
337355

356+
# TODO docs
357+
def composite_assoc_query(query, _binding_src, [], []) do
358+
query
359+
end
360+
def composite_assoc_query(query, binding_dst, [dst_key | dst_keys], [value | values]) do
361+
# TODO
362+
[query, binding_dst, [dst_key | dst_keys], [value | values]] |> IO.inspect(label: :composite_assoc_query)
363+
query
364+
end
365+
338366
@doc """
339367
Add the default assoc query where clauses a provided query.
340368
"""
@@ -632,6 +660,10 @@ defmodule Ecto.Association do
632660

633661
defp primary_key!(nil), do: []
634662
defp primary_key!(struct), do: Ecto.primary_key!(struct)
663+
664+
def missing_fields(queryable, related_key) do
665+
Enum.filter related_key, &is_nil(queryable.__schema__(:type, &1))
666+
end
635667
end
636668

637669
defmodule Ecto.Association.Has do
@@ -644,8 +676,8 @@ defmodule Ecto.Association.Has do
644676
* `field` - The name of the association field on the schema
645677
* `owner` - The schema where the association was defined
646678
* `related` - The schema that is associated
647-
* `owner_key` - The key on the `owner` schema used for the association
648-
* `related_key` - The key on the `related` schema used for the association
679+
* `owner_key` - The list of columns that form the key on the `owner` schema used for the association
680+
* `related_key` - The list of columns that form the key on the `related` schema used for the association
649681
* `queryable` - The real query to use for querying association
650682
* `on_delete` - The action taken on associations when schema is deleted
651683
* `on_replace` - The action taken on associations when schema is replaced
@@ -673,8 +705,8 @@ defmodule Ecto.Association.Has do
673705
{:error, "associated schema #{inspect queryable} does not exist"}
674706
not function_exported?(queryable, :__schema__, 2) ->
675707
{:error, "associated module #{inspect queryable} is not an Ecto schema"}
676-
is_nil queryable.__schema__(:type, related_key) ->
677-
{:error, "associated schema #{inspect queryable} does not have field `#{related_key}`"}
708+
[] != (missing_fields = Ecto.Association.missing_fields(queryable, related_key)) ->
709+
{:error, "associated schema #{inspect queryable} does not have field(s) `#{inspect missing_fields}`"}
678710
true ->
679711
:ok
680712
end
@@ -686,14 +718,17 @@ defmodule Ecto.Association.Has do
686718
cardinality = Keyword.fetch!(opts, :cardinality)
687719
related = Ecto.Association.related_from_query(queryable, name)
688720

689-
ref =
721+
refs =
690722
module
691723
|> Module.get_attribute(:primary_key)
692724
|> get_ref(opts[:references], name)
725+
|> List.wrap()
693726

694-
unless Module.get_attribute(module, :ecto_fields)[ref] do
695-
raise ArgumentError, "schema does not have the field #{inspect ref} used by " <>
696-
"association #{inspect name}, please set the :references option accordingly"
727+
for ref <- refs do
728+
unless Module.get_attribute(module, :ecto_fields)[ref] do
729+
raise ArgumentError, "schema does not have the field #{inspect ref} used by " <>
730+
"association #{inspect name}, please set the :references option accordingly"
731+
end
697732
end
698733

699734
if opts[:through] do
@@ -725,13 +760,19 @@ defmodule Ecto.Association.Has do
725760
raise ArgumentError, "expected `:where` for #{inspect name} to be a keyword list, got: `#{inspect where}`"
726761
end
727762

763+
foreign_key = case opts[:foreign_key] do
764+
nil -> Enum.map(refs, &Ecto.Association.association_key(module, &1))
765+
key when is_atom(key) -> [key]
766+
keys when is_list(keys) -> keys
767+
end
768+
728769
%__MODULE__{
729770
field: name,
730771
cardinality: cardinality,
731772
owner: module,
732773
related: related,
733-
owner_key: ref,
734-
related_key: opts[:foreign_key] || Ecto.Association.association_key(module, ref),
774+
owner_key: refs,
775+
related_key: foreign_key,
735776
queryable: queryable,
736777
on_delete: on_delete,
737778
on_replace: on_replace,
@@ -756,19 +797,23 @@ defmodule Ecto.Association.Has do
756797

757798
@impl true
758799
def joins_query(%{related_key: related_key, owner: owner, owner_key: owner_key, queryable: queryable} = assoc) do
759-
from(o in owner, join: q in ^queryable, on: field(q, ^related_key) == field(o, ^owner_key))
800+
# TODO find out how to handle a dynamic list of fields here
801+
from(o in owner, join: q in ^queryable, on: field(q, ^hd(related_key)) == field(o, ^hd(owner_key)))
802+
|> Ecto.Association.composite_joins_query(0, 1, tl(related_key), tl(owner_key))
760803
|> Ecto.Association.combine_joins_query(assoc.where, 1)
761804
end
762805

763806
@impl true
764807
def assoc_query(%{related_key: related_key, queryable: queryable} = assoc, query, [value]) do
765-
from(x in (query || queryable), where: field(x, ^related_key) == ^value)
808+
from(x in (query || queryable), where: field(x, ^hd(related_key)) == ^hd(value))
809+
|> Ecto.Association.composite_assoc_query(0, tl(related_key), tl(value))
766810
|> Ecto.Association.combine_assoc_query(assoc.where)
767811
end
768812

769813
@impl true
770814
def assoc_query(%{related_key: related_key, queryable: queryable} = assoc, query, values) do
771-
from(x in (query || queryable), where: field(x, ^related_key) in ^values)
815+
from(x in (query || queryable), where: field(x, ^hd(related_key)) in ^Enum.map(values, &hd/1))
816+
|> Ecto.Association.composite_assoc_query(0, tl(related_key), Enum.map(values, &tl/1))
772817
|> Ecto.Association.combine_assoc_query(assoc.where)
773818
end
774819

@@ -807,16 +852,21 @@ defmodule Ecto.Association.Has do
807852
%{data: parent, repo: repo} = parent_changeset
808853
%{action: action, changes: changes} = changeset
809854

810-
{key, value} = parent_key(assoc, parent)
811-
changeset = update_parent_key(changeset, action, key, value)
812-
changeset = Ecto.Association.update_parent_prefix(changeset, parent)
855+
parent_keys = parent_keys(assoc, parent)
856+
changeset = Enum.reduce parent_keys, changeset, fn {key, value}, changeset ->
857+
changeset = update_parent_key(changeset, action, key, value)
858+
Ecto.Association.update_parent_prefix(changeset, parent)
859+
end
813860

814861
case apply(repo, action, [changeset, opts]) do
815862
{:ok, _} = ok ->
816863
if action == :delete, do: {:ok, nil}, else: ok
817864
{:error, changeset} ->
818-
original = Map.get(changes, key)
819-
{:error, put_in(changeset.changes[key], original)}
865+
changeset = Enum.reduce parent_keys, changeset, fn {key, _}, changeset ->
866+
original = Map.get(changes, key)
867+
put_in(changeset.changes[key], original)
868+
end
869+
{:error, changeset}
820870
end
821871
end
822872

@@ -825,11 +875,21 @@ defmodule Ecto.Association.Has do
825875
defp update_parent_key(changeset, _action, key, value),
826876
do: Ecto.Changeset.put_change(changeset, key, value)
827877

828-
defp parent_key(%{related_key: related_key}, nil) do
829-
{related_key, nil}
878+
defp parent_keys(%{related_key: related_keys}, nil) when is_list(related_keys) do
879+
Enum.map related_keys, fn related_key -> {related_key, nil} end
880+
end
881+
defp parent_keys(%{related_key: related_key}, nil) do
882+
[{related_key, nil}]
883+
end
884+
defp parent_keys(%{owner_key: owner_keys, related_key: related_keys}, owner) when is_list(owner_keys) and is_list(related_keys) do
885+
owner_keys
886+
|> Enum.zip(related_keys)
887+
|> Enum.map(fn {owner_key, related_key} ->
888+
{related_key, Map.get(owner, owner_key)}
889+
end)
830890
end
831-
defp parent_key(%{owner_key: owner_key, related_key: related_key}, owner) do
832-
{related_key, Map.get(owner, owner_key)}
891+
defp parent_keys(%{owner_key: owner_key, related_key: related_key}, owner) do
892+
[{related_key, Map.get(owner, owner_key)}]
833893
end
834894

835895
## Relation callbacks
@@ -982,16 +1042,16 @@ defmodule Ecto.Association.BelongsTo do
9821042
{:error, "associated schema #{inspect queryable} does not exist"}
9831043
not function_exported?(queryable, :__schema__, 2) ->
9841044
{:error, "associated module #{inspect queryable} is not an Ecto schema"}
985-
is_nil queryable.__schema__(:type, related_key) ->
986-
{:error, "associated schema #{inspect queryable} does not have field `#{related_key}`"}
1045+
[] != (missing_fields = Ecto.Association.missing_fields(queryable, related_key)) ->
1046+
{:error, "associated schema #{inspect queryable} does not have field(s) `#{inspect missing_fields}`"}
9871047
true ->
9881048
:ok
9891049
end
9901050
end
9911051

9921052
@impl true
9931053
def struct(module, name, opts) do
994-
ref = if ref = opts[:references], do: ref, else: :id
1054+
refs = if ref = opts[:references], do: List.wrap(ref), else: [:id]
9951055
queryable = Keyword.fetch!(opts, :queryable)
9961056
related = Ecto.Association.related_from_query(queryable, name)
9971057
on_replace = Keyword.get(opts, :on_replace, :raise)
@@ -1013,8 +1073,8 @@ defmodule Ecto.Association.BelongsTo do
10131073
field: name,
10141074
owner: module,
10151075
related: related,
1016-
owner_key: Keyword.fetch!(opts, :foreign_key),
1017-
related_key: ref,
1076+
owner_key: List.wrap(Keyword.fetch!(opts, :foreign_key)),
1077+
related_key: refs,
10181078
queryable: queryable,
10191079
on_replace: on_replace,
10201080
defaults: defaults,
@@ -1031,19 +1091,22 @@ defmodule Ecto.Association.BelongsTo do
10311091

10321092
@impl true
10331093
def joins_query(%{related_key: related_key, owner: owner, owner_key: owner_key, queryable: queryable} = assoc) do
1034-
from(o in owner, join: q in ^queryable, on: field(q, ^related_key) == field(o, ^owner_key))
1094+
from(o in owner, join: q in ^queryable, on: field(q, ^hd(related_key)) == field(o, ^hd(owner_key)))
1095+
|> Ecto.Association.composite_joins_query(0, 1, tl(related_key), tl(owner_key))
10351096
|> Ecto.Association.combine_joins_query(assoc.where, 1)
10361097
end
10371098

10381099
@impl true
10391100
def assoc_query(%{related_key: related_key, queryable: queryable} = assoc, query, [value]) do
1040-
from(x in (query || queryable), where: field(x, ^related_key) == ^value)
1101+
from(x in (query || queryable), where: field(x, ^hd(related_key)) == ^hd(value))
1102+
|> Ecto.Association.composite_assoc_query(0, tl(related_key), tl(value))
10411103
|> Ecto.Association.combine_assoc_query(assoc.where)
10421104
end
10431105

10441106
@impl true
10451107
def assoc_query(%{related_key: related_key, queryable: queryable} = assoc, query, values) do
1046-
from(x in (query || queryable), where: field(x, ^related_key) in ^values)
1108+
from(x in (query || queryable), where: field(x, ^hd(related_key)) in ^Enum.map(values, &hd/1))
1109+
|> Ecto.Association.composite_assoc_query(0, tl(related_key), Enum.map(values, &tl/1))
10471110
|> Ecto.Association.combine_assoc_query(assoc.where)
10481111
end
10491112

@@ -1264,11 +1327,12 @@ defmodule Ecto.Association.ManyToMany do
12641327

12651328
owner_key_type = owner.__schema__(:type, owner_key)
12661329

1330+
# TODO fix the hd(values)
12671331
# We only need to join in the "join table". Preload and Ecto.assoc expressions can then filter
12681332
# by &1.join_owner_key in ^... to filter down to the associated entries in the related table.
12691333
from(q in (query || queryable),
12701334
join: j in ^join_through, on: field(q, ^related_key) == field(j, ^join_related_key),
1271-
where: field(j, ^join_owner_key) in type(^values, {:in, ^owner_key_type})
1335+
where: field(j, ^join_owner_key) in type(^hd(values), {:in, ^owner_key_type})
12721336
)
12731337
|> Ecto.Association.combine_assoc_query(assoc.where)
12741338
|> Ecto.Association.combine_joins_query(assoc.join_where, 1)

0 commit comments

Comments
 (0)