diff --git a/session_stats.go b/session_stats.go index c2cac8306..88286e507 100644 --- a/session_stats.go +++ b/session_stats.go @@ -30,6 +30,10 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { args = session.statement.RawParams } + if len(session.statement.selectStr) > 0 { + sqlStr = "SELECT COUNT(*) FROM ("+sqlStr+") _TEMP_" + } + var total int64 err = session.queryRow(sqlStr, args...).Scan(&total) if err == sql.ErrNoRows || err == nil { diff --git a/session_stats_test.go b/session_stats_test.go index b66a84b4a..3d4159c40 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -124,6 +124,10 @@ func TestCount(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + total, err = testEngine.Where(cond).Select(colName).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + total, err = testEngine.Where(cond).Count(new(UserinfoCount)) assert.NoError(t, err) assert.EqualValues(t, 1, total)