@@ -23,6 +23,7 @@ package graph
2323import (
2424 "fmt"
2525 "io"
26+ "slices"
2627 "sort"
2728)
2829
@@ -60,9 +61,10 @@ type dag[T any] struct {
6061}
6162
6263type vertex [T any ] struct {
63- edges []* vertex [T ]
64- name string
65- object T
64+ edges []* vertex [T ]
65+ parents []string
66+ name string
67+ object T
6668}
6769
6870func (v * vertex [T ]) IsRoot () bool {
@@ -79,7 +81,7 @@ func (d *dag[T]) Add(name string, object T, dependencies ...string) error {
7981 vert , ok := d .placeholders [name ]
8082 delete (d .placeholders , name )
8183 if ! ok {
82- vert = & vertex [T ]{name : name , object : object }
84+ vert = & vertex [T ]{name : name , object : object , parents : dependencies }
8385 } else {
8486 // set object on placeholder
8587 vert .object = object
@@ -107,23 +109,62 @@ func (d *dag[T]) Add(name string, object T, dependencies ...string) error {
107109 return nil
108110}
109111
112+ // leaves handle edge cases where there are more then one leaf, and from is a subset of the leaves not in the from
113+ func (d * dag [T ]) leaves (from []string ) []string {
114+ leaves := make ([]string , 0 )
115+ for _ , f := range d .leafs {
116+ leaves = append (leaves , f .name )
117+ }
118+ if len (from ) > len (leaves ) {
119+ return nil
120+ }
121+ dif := diff (leaves , from )
122+ return dif
123+ }
124+
125+ // diff returns the elements in a that are not in b
126+ func diff (a , b []string ) []string {
127+ ret := make ([]string , 0 )
128+ for _ , v := range a {
129+ if ! slices .Contains (b , v ) {
130+ ret = append (ret , v )
131+ }
132+ }
133+ return ret
134+ }
135+
110136func (d * dag [T ]) Next (from ... string ) ([]string , error ) {
111137 if err := d .Valid (); err != nil {
112138 return nil , err
113139 }
114140
115- if len (from ) == 0 { // base staring case
141+ if len (from ) == 0 { // base starting case
116142 return getNames (flat (d .leafs )), nil
117143 }
144+ leaves := d .leaves (from )
145+ if len (leaves ) > 0 {
146+ return leaves , nil
147+ }
118148
119149 // Use a map to deduplicate edges
120150 seen := make (map [string ]* vertex [T ])
121151 for _ , f := range from {
122152 vert := d .vertices [f ]
123- for _ , edge := range d .vertices [f ].edges {
124- // search the path before adding it, if the path is longer than 1 then we don't want to add it
125- // this aids in walking the graph later because erroneous paths have been pre pruned
126- if len (d .find_longest_path (vert .name , edge .name )) <= 1 {
153+ for _ , edge := range vert .edges {
154+ // Skip if already processed
155+ if slices .Contains (from , edge .name ) {
156+ continue
157+ }
158+
159+ // Check if all parents are in the completed set
160+ allParentsSatisfied := true
161+ for _ , parent := range edge .parents {
162+ if ! slices .Contains (from , parent ) {
163+ allParentsSatisfied = false
164+ break
165+ }
166+ }
167+ if allParentsSatisfied {
127168 seen [edge .name ] = edge
128169 }
129170 }
@@ -135,22 +176,7 @@ func (d *dag[T]) Next(from ...string) ([]string, error) {
135176 root = append (root , v )
136177 }
137178
138- // remove matching from input
139- in := make (map [string ]struct {})
140- for _ , f := range from {
141- in [f ] = struct {}{}
142- }
143-
144- temp := root [:0 ]
145- for _ , out := range root {
146- if _ , ok := in [out .name ]; ! ok {
147- temp = append (temp , out )
148- }
149- }
150- root = temp
151-
152179 sortEdges (root )
153-
154180 return getNames (root ), nil
155181}
156182
@@ -197,44 +223,18 @@ func getNames[T any](vs []*vertex[T]) []string {
197223 return ret
198224}
199225
200- // search dag for longest common path
201- func (d * dag [T ]) find_longest_path (start , end string ) []string {
202-
203- from := d .vertices [start ]
204- if from .IsRoot () {
205- return nil
206- }
207-
208- paths := make (map [string ][]string )
209- for _ , vert := range from .edges {
210- if vert .name == end {
211- paths [vert .name ] = []string {end }
212- }
213- path := d .find_longest_path (vert .name , end )
214- if path != nil {
215- paths [vert .name ] = append ([]string {vert .name }, path ... )
216- }
217- }
218- var ret []string
219- for _ , v := range paths {
220- if len (v ) > len (ret ) {
221- ret = v
222- }
223- }
224- return ret
225- }
226-
227226// PrintGraph is helper function for print out a DependencyGraph graph,
228227// or any thing that implements the Walk interface
229228func PrintGraph [T any ](out io.Writer , d Walk [T ]) error {
230- step := make ([]string , 0 )
231229 ret := make ([][]string , 0 )
230+ completed := make ([]string , 0 )
232231 for {
233- step , _ = d .Next (step ... )
232+ step , _ : = d .Next (completed ... )
234233 if len (step ) == 0 {
235- break // end
234+ break // end of graph
236235 }
237236 ret = append (ret , step )
237+ completed = append (completed , step ... )
238238 }
239239
240240 for i := range ret {
0 commit comments