Skip to content

[EXP] Reader iter api experiments #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion adapters/reader.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package adapters

import "github.com/interline-io/transitland-lib/gtfs"
import (
"github.com/interline-io/transitland-lib/gtfs"
)

// Reader is the main interface for reading GTFS data
type Reader interface {
Expand Down
1 change: 0 additions & 1 deletion extract/setterfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ func (tx *SetterFilter) AddValuesFromFile(filename string) error {
tx.AddValue(efn, eid, key, val)
})
return nil

}

// AddValue sets a new value to override.
Expand Down
63 changes: 63 additions & 0 deletions tlcsv/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"io/ioutil"
"iter"
"net/url"
"os"
"path/filepath"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/interline-io/log"
"github.com/interline-io/transitland-lib/causes"
"github.com/interline-io/transitland-lib/request"
"github.com/interline-io/transitland-lib/tt"
)

// Adapter provides an interface for working with various kinds of GTFS sources: zip, directory, url.
Expand Down Expand Up @@ -284,6 +286,67 @@ func (adapter *ZipAdapter) OpenFile(filename string, cb func(io.Reader)) error {
return nil
}

// canFilename is satisfied by any value that can report the name of the
// file it is read from.
type canFilename interface {
	Filename() string
}

// getFilename returns the result of entType.Filename() when entType
// implements canFilename, and the empty string for every other value
// (including nil).
func getFilename(entType any) string {
	named, ok := entType.(canFilename)
	if !ok {
		return ""
	}
	return named.Filename()
}

// ReadEntityRows returns an iterator over the rows of the file associated
// with entType, plus a function reporting the error (if any) from the most
// recent run of that iterator.
//
// The file name is resolved with getFilename, so entType should provide a
// Filename() method; otherwise an empty file name is handed to OpenFile.
//
// The returned error function is only meaningful after the iterator has been
// run. It reports the OpenFile error if opening failed, and otherwise the
// error reported by ReadRowsIter.
func (adapter *ZipAdapter) ReadEntityRows(entType any) (iter.Seq[tt.Row], func() error) {
	var readErr error
	errf := func() error { return readErr }
	seq := func(yield func(tt.Row) bool) {
		var rowsErr error
		openErr := adapter.OpenFile(getFilename(entType), func(in io.Reader) {
			rows, rowsErrf := ReadRowsIter(in)
			for row := range rows {
				// The consumer asked to stop; calling yield again after it
				// returned false is a runtime panic.
				if !yield(row) {
					break
				}
			}
			// Surface read errors instead of silently discarding them.
			rowsErr = rowsErrf()
		})
		if openErr != nil {
			readErr = openErr
			return
		}
		readErr = rowsErr
	}
	return seq, errf
}

// ReadRows opens the specified file and runs the callback on each Row. An error is returned if the file cannot be read.
func (adapter *ZipAdapter) ReadRows(filename string, cb func(Row)) error {
return adapter.OpenFile(filename, func(in io.Reader) {
Expand Down
4 changes: 0 additions & 4 deletions tlcsv/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,6 @@ func ReadEntities[T any](reader *Reader, efn string) chan T {
return eout
}

// getFilename reports the file name of an entity, as given by its Filename
// method.
func getFilename(ent tt.Entity) string {
	name := ent.Filename()
	return name
}

// chunkMSI takes a string counter and chunks it into groups of size <= chunkSize
func chunkMSI(count map[string]int, chunkSize int) s2D {
result := s2D{}
Expand Down
50 changes: 31 additions & 19 deletions tlcsv/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/dimchansky/utfbom"
"github.com/interline-io/transitland-lib/tt"
)

// Row is a row value with a header.
Expand Down Expand Up @@ -43,10 +44,6 @@ func ReadRows(in io.Reader, cb func(Row)) error {
// Allow unescaped quotes
r.LazyQuotes = true
// Go
return readRows(r, cb)
}

func readRows(r *csv.Reader, cb func(Row)) error {
// Go for it.
firstRow, err := r.Read()
if err != nil {
Expand Down Expand Up @@ -90,7 +87,7 @@ func readRows(r *csv.Reader, cb func(Row)) error {
return nil
}

func ReadRowsIter(in io.Reader, optFns ...csvOptFn) iter.Seq2[Row, error] {
func ReadRowsIter(in io.Reader, optFns ...csvOptFn) (iter.Seq[tt.Row], func() error) {
// Handle byte-order-marks.
r := csv.NewReader(utfbom.SkipOnly(in))
// Allow variable columns - very common in GTFS
Expand All @@ -105,11 +102,14 @@ func ReadRowsIter(in io.Reader, optFns ...csvOptFn) iter.Seq2[Row, error] {
for _, optFn := range optFns {
optFn(r)
}
return func(yield func(Row, error) bool) {
var anyValues []any
var readErr error
errf := func() error { return readErr }
return func(yield func(tt.Row) bool) {
// Go for it.
firstRow, err := r.Read()
if err != nil {
yield(Row{}, err)
firstRow, firstRowErr := r.Read()
if firstRowErr != nil {
readErr = firstRowErr
return
}
// Copy header, since we will reuse the backing array
Expand All @@ -122,34 +122,46 @@ func ReadRowsIter(in io.Reader, optFns ...csvOptFn) iter.Seq2[Row, error] {
for k, i := range header {
hindex[i] = k
}
// Reusable slice
anyValues = make([]any, len(header))
// Read all rows
for {
row, err := r.Read()
if err == nil {
row, rowErr := r.Read()
if rowErr == nil {
// ok
} else if err == io.EOF {
} else if rowErr == io.EOF {
break
} else if _, ok := err.(*csv.ParseError); ok {
} else if _, ok := rowErr.(*csv.ParseError); ok {
// Parse error: clear row, add error to row
row = []string{}
for i := 0; i < len(anyValues); i++ {
anyValues[i] = nil
}
} else {
// Serious error: break and return with error
yield(Row{}, err)
readErr = rowErr
return
}
// Remove whitespace
for i := 0; i < len(row); i++ {
v := row[i]
// This is dumb but saves substantial time.
anyValues[i] = v
if len(v) > 0 && (v[0] == ' ' || v[len(v)-1] == ' ' || v[0] == '\t' || v[len(v)-1] == '\t') {
row[i] = strings.TrimSpace(v)
anyValues[i] = strings.TrimSpace(v)
}
}
// Pass parse errors to row
line, _ := r.FieldPos(0)
cbrow := Row{Row: row, Line: line, Header: header, Hindex: hindex, Err: err}
if !yield(cbrow, nil) {
cbrow := tt.Row{
Values: anyValues,
Line: line,
Header: header,
Hindex: hindex,
Err: rowErr,
}
if !yield(cbrow) {
return
}
}
}
}, errf
}
26 changes: 26 additions & 0 deletions tlcsv/row_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package tlcsv

import (
"fmt"
"testing"

"github.com/interline-io/transitland-lib/gtfs"
"github.com/interline-io/transitland-lib/internal/testpath"
"github.com/interline-io/transitland-lib/tt"
)

// TestReadRowsIter opens the example zip feed, iterates every gtfs.Stop
// produced by tt.ReadEntitiesIter, prints each stop name, and fails if the
// adapter cannot be opened or the iterator reports an error.
func TestReadRowsIter(t *testing.T) {
	adapter := &ZipAdapter{path: testpath.RelPath("testdata/example.zip")}
	openErr := adapter.Open()
	if openErr != nil {
		t.Fatal(openErr)
	}

	stops, stopsErr := tt.ReadEntitiesIter[gtfs.Stop](adapter)
	for stop := range stops {
		fmt.Println("ent:", stop.StopName)
	}
	err := stopsErr()
	if err != nil {
		t.Error(err)
	}
}
139 changes: 139 additions & 0 deletions tt/row.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package tt

import (
"fmt"
"iter"
"reflect"

"github.com/interline-io/transitland-lib/causes"
"github.com/interline-io/transitland-lib/internal/tags"
"github.com/jmoiron/sqlx/reflectx"
)

// bufferSize is the capacity of the buffered channel created by ReadEntities.
const bufferSize = 1_000

// MapperCache is the shared tag cache that loadRowReflect uses to look up a
// struct's tag map (GetStructTagMap). It is built on a reflectx mapper that
// reads the "csv" struct tag and is given tags.ToSnakeCase as its name
// mapping function.
// NOTE(review): the exact caching behavior lives in internal/tags — confirm there.
var MapperCache = tags.NewCache(reflectx.NewMapperFunc("csv", tags.ToSnakeCase))

// Row is a single record together with the header it was read under.
type Row struct {
	// Header holds the column names.
	Header []string
	// Values holds the record's values, positionally aligned with Header.
	Values []any
	// Hindex maps a column name to its position in Values.
	Hindex map[string]int
	// Line is the source line number of the record.
	Line int
	// Err is a row-level error attached by the reader, if any.
	Err error
}

// Get looks up the value stored under column k. The second result is false
// when the column is unknown or the row has no value at that position.
func (row *Row) Get(k string) (any, bool) {
	idx, found := row.Hindex[k]
	if !found || idx >= len(row.Values) {
		return nil, false
	}
	return row.Values[idx], true
}

// GetString looks up the value stored under column k and converts it to a
// string with toStrv. The second result is false when the column is unknown
// or the row has no value at that position.
func (row *Row) GetString(k string) (string, bool) {
	idx, found := row.Hindex[k]
	if !found || idx >= len(row.Values) {
		return "", false
	}
	return toStrv(row.Values[idx]), true
}

// RowReader is a source of rows for a given entity type.
type RowReader interface {
// ReadEntityRows returns a sequence of rows for the entity type described
// by the argument, plus a function that returns the read error, if any.
// ReadEntitiesIter calls the error function only after ranging over the
// sequence.
ReadEntityRows(any) (iter.Seq[Row], func() error)
}

// ReadEntities reads every row for entity type T from reader, loads each row
// into a fresh T with loadRowReflect, and delivers the results on a buffered
// channel (capacity bufferSize). The channel is closed once all rows have
// been sent.
//
// The consumer must drain the channel; otherwise the producing goroutine
// blocks forever on send. Use ReadEntitiesIter when early termination or
// access to the read error is needed.
func ReadEntities[T any](reader RowReader) chan T {
	eout := make(chan T, bufferSize)
	go func() {
		defer close(eout)
		// Pass a *T, as ReadEntitiesIter does, so that readers resolving the
		// source through methods (e.g. Filename) also see pointer-receiver
		// methods.
		// The error function is intentionally dropped: a bare channel has no
		// way to report it to the caller.
		rows, _ := reader.ReadEntityRows(new(T))
		for row := range rows {
			var e T
			// Load errors are attached to the entity itself when it supports
			// that (see loadRowReflect), so the return value is not needed.
			loadRowReflect(&e, row)
			eout <- e
		}
	}()
	return eout
}

// ReadEntitiesIter returns an iterator over entities of type T read from
// reader, plus a function that reports the reader's error.
//
// Each row is loaded into a fresh T with loadRowReflect. The error function
// is only meaningful after the iterator has been run; it returns whatever
// the reader's own error function reported for that run.
func ReadEntitiesIter[T any](reader RowReader) (iter.Seq[T], func() error) {
	var readErr error
	// A *T is handed to the reader so it can discover Filename() or
	// TableName() even when those have pointer receivers.
	entType := new(T)
	seq := func(yield func(T) bool) {
		rows, errf := reader.ReadEntityRows(entType)
		for row := range rows {
			// NOTE(review): debug output on every row; drop this together
			// with the file's "fmt" import before merging.
			fmt.Println("row:", row)
			var e T
			loadRowReflect(&e, row)
			// Stop as soon as the consumer is done; calling yield again
			// after it returned false is a runtime panic.
			if !yield(e) {
				break
			}
		}
		readErr = errf()
	}
	return seq, func() error { return readErr }
}

// loadRowReflect populates ent from row using reflection.
//
// ent must be a non-nil pointer to a struct. For every column that has both
// a header name and a value:
//   - if no struct field is mapped to the column name, the value is passed
//     to SetExtra when ent implements EntityWithExtra;
//   - nil and empty-string values are skipped;
//   - otherwise the value is stored through the field's Scan method when the
//     field address implements canScan, or through convertAssign.
//
// All conversion errors are returned, and are also added to ent with
// AddError when ent implements EntityWithLoadErrors.
func loadRowReflect(ent any, row Row) []error {
	var errs []error
	// Get the struct tag map
	fmap := MapperCache.GetStructTagMap(ent)
	// For each struct tag, set the field value
	entValue := reflect.ValueOf(ent).Elem()
	// Bound the loop by both lengths: a row may carry fewer values than the
	// header has columns, and indexing row.Values at len(row.Values) panics.
	for i := 0; i < len(row.Header) && i < len(row.Values); i++ {
		fieldName := row.Header[i]
		fieldValue := row.Values[i]
		fieldInfo, ok := fmap[fieldName]

		// Add to extra fields if there's no struct tag
		if !ok {
			if extEnt, ok2 := ent.(EntityWithExtra); ok2 {
				extEnt.SetExtra(fieldName, toStrv(fieldValue))
			}
			continue
		}

		// Skip if empty, special case for strings
		if fieldValue == nil {
			continue
		}
		if s, isStr := fieldValue.(string); isStr && s == "" {
			continue
		}

		// Handle different known types
		entFieldAddr := reflectx.FieldByIndexes(entValue, fieldInfo.Index).Addr().Interface()
		if v, ok := entFieldAddr.(canScan); ok {
			if err := v.Scan(fieldValue); err != nil {
				errs = append(errs, err)
			}
		} else if _, scanErr := convertAssign(entFieldAddr, fieldValue); scanErr != nil {
			errs = append(errs, causes.NewFieldParseError(fieldName, toStrv(fieldValue)))
		}
	}
	// Record the errors on the entity itself when it can hold them
	if extEnt, ok := ent.(EntityWithLoadErrors); ok {
		for _, err := range errs {
			extEnt.AddError(err)
		}
	}
	return errs
}

// toStrv renders value as a string. Strings are returned unchanged; any
// other value is converted with convertAssign, and the result is the empty
// string if that conversion stores nothing.
func toStrv(value any) string {
	switch s := value.(type) {
	case string:
		return s
	}
	var out string
	// The conversion result is deliberately ignored: on failure the empty
	// string is returned.
	convertAssign(&out, value)
	return out
}