Skip to content

Commit cde09bc

Browse files
committed
feat(dataframe): SeriesGroupBy has apply method, returns new Series
1 parent 666fcb0 commit cde09bc

File tree

6 files changed

+97
-2
lines changed

6 files changed

+97
-2
lines changed

dataframe/index.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ func (i *Index) Type() string {
7878
func (i *Index) Attr(name string) (starlark.Value, error) {
7979
switch name {
8080
case "name":
81+
if i == nil {
82+
// TODO(dustmop): Add a test that covers this
83+
return starlark.None, nil
84+
}
8185
return starlark.String(i.name), nil
8286
case "str":
8387
return &stringMethods{subject: i}, nil

dataframe/series.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,16 @@ func builtinAttrNames(methods map[string]*starlark.Builtin) []string {
450450
return names
451451
}
452452

453+
// name returns of the name of the series
454+
func seriesAttrName(self *Series) (starlark.Value, error) {
455+
return starlark.String(self.name), nil
456+
}
457+
458+
// size returns the number of elements in the series
459+
func seriesAttrSize(self *Series) (starlark.Value, error) {
460+
return starlark.MakeInt(self.Len()), nil
461+
}
462+
453463
func seriesGet(_ *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
454464
var key starlark.Value
455465
if err := starlark.UnpackPositionalArgs(b.Name(), args, kwargs, 1, &key); err != nil {
@@ -746,6 +756,14 @@ func newSeriesFromFloats(vals []float64, index *Index, name string) *Series {
746756
}
747757
}
748758

759+
func newSeriesFromStrings(texts []string, index *Index, name string) *Series {
760+
results := make([]interface{}, len(texts))
761+
for i, txt := range texts {
762+
results[i] = txt
763+
}
764+
return newSeriesFromObjects(results, index, name)
765+
}
766+
749767
func newSeriesFromObjects(vals []interface{}, index *Index, name string) *Series {
750768
return &Series{
751769
dtype: "object",

dataframe/series_all_methods.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ var seriesAttributes = map[string]seriesAttrImpl{
2424
"is_monotonic_increasing": attrNoImplSeries("is_monotonic_increasing"),
2525
"is_unique": attrNoImplSeries("is_unique"),
2626
"loc": attrNoImplSeries("loc"),
27-
"name": attrNoImplSeries("name"),
27+
"name": seriesAttrName,
2828
"nbytes": attrNoImplSeries("nbytes"),
2929
"ndim": attrNoImplSeries("ndim"),
3030
"shape": attrNoImplSeries("shape"),
31-
"size": attrNoImplSeries("size"),
31+
"size": seriesAttrSize,
3232
"values": attrNoImplSeries("values"),
3333
}
3434

dataframe/series_groupby_result.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
type SeriesGroupByResult struct {
1313
lhsLabel string
1414
rhsLabel string
15+
// TODO(dustmop): convert to map[string]Series
1516
grouping map[string][]string
1617
}
1718

@@ -24,6 +25,7 @@ var (
2425
var seriesGroupByResultMethods = map[string]*starlark.Builtin{
2526
"count": starlark.NewBuiltin("count", seriesGroupByResultCount),
2627
"sum": starlark.NewBuiltin("sum", seriesGroupByResultSum),
28+
"apply": starlark.NewBuiltin("apply", seriesGroupByResultApply),
2729
}
2830

2931
// Freeze has no effect on the immutable SeriesGroupByResult
@@ -113,6 +115,55 @@ func seriesGroupByResultCount(_ *starlark.Thread, b *starlark.Builtin, args star
113115
return newSeriesFromInts(vals, index, self.rhsLabel), nil
114116
}
115117

118+
// apply method returns a Series that is built by calling the given
119+
// function, and passing each grouped series as an argument to it
120+
func seriesGroupByResultApply(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
121+
var (
122+
funcVal starlark.Value
123+
self = b.Receiver().(*SeriesGroupByResult)
124+
)
125+
126+
if err := starlark.UnpackArgs("apply", args, kwargs,
127+
"function", &funcVal,
128+
); err != nil {
129+
return nil, err
130+
}
131+
132+
funcObj, ok := funcVal.(*starlark.Function)
133+
if !ok {
134+
return nil, fmt.Errorf("first argument must be a function")
135+
}
136+
137+
sortedKeys := getSortedKeys(self.grouping)
138+
builder := newTypedSliceBuilder(len(sortedKeys))
139+
indexNames := make([]string, len(sortedKeys))
140+
141+
for i, groupName := range sortedKeys {
142+
values := self.grouping[groupName]
143+
// TODO(dustmop): Pass actual index here
144+
index := NewIndex(nil, groupName)
145+
series := newSeriesFromStrings(values, index, groupName)
146+
arguments := starlark.Tuple{series}
147+
// Call function, passing the series to it
148+
res, err := starlark.Call(thread, funcObj, arguments, nil)
149+
if err != nil {
150+
return nil, err
151+
}
152+
obj, ok := toScalarMaybe(res)
153+
if !ok {
154+
return nil, fmt.Errorf("could not convert: %v", res)
155+
}
156+
// Accumulate the new series, and build the new index
157+
builder.push(obj)
158+
indexNames[i] = groupName
159+
}
160+
if err := builder.error(); err != nil {
161+
return nil, err
162+
}
163+
s := builder.toSeries(NewIndex(indexNames, self.lhsLabel), self.rhsLabel)
164+
return &s, nil
165+
}
166+
116167
func getSortedKeys(m map[string][]string) []string {
117168
keys := make([]string, 0, len(m))
118169
for k := range m {

dataframe/testdata/dataframe_groupby.expect.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@ cat 3
1111
dog 2
1212
Name: breed, dtype: int64
1313
Index(['cat', 'dog'], dtype='object', name='species')
14+
15+
species
16+
cat tabby cat, black cat, calico cat
17+
dog doberman dog, pug dog
18+
Name: breed, dtype: object
19+
20+
dataframe.Series
21+
breed
22+
species

dataframe/testdata/dataframe_groupby.star

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
load("dataframe.star", "dataframe")
22

33

4+
def example_animals(series):
    # Label every element of the group with the series name,
    # e.g. "tabby" in the series named "cat" becomes "tabby cat".
    labelled = []
    for i in range(series.size):
        labelled.append('{} {}'.format(series[i], series.name))
    return ', '.join(labelled)
7+
8+
49
def f():
510
df = dataframe.DataFrame({"IDs": ["cat", "dog", "eel", "dog", "cat", "frog", "cat", "eel"],
611
"count": [1, 2, 3, 4, 5, 6, 7, 8]})
@@ -19,6 +24,14 @@ def f():
1924
num_breeds = df.groupby(['species'])['breed'].count()
2025
print(num_breeds)
2126
print(num_breeds.index)
27+
print('')
28+
29+
list_of_examples = df.groupby(['species'])['breed'].apply(example_animals)
30+
print(list_of_examples)
31+
print('')
32+
print(type(list_of_examples))
33+
print(list_of_examples.name)
34+
print(list_of_examples.index.name)
2235

2336

2437
f()

0 commit comments

Comments
 (0)