diff --git a/core/user.go b/core/user.go index 6c37c0c6..175375ff 100644 --- a/core/user.go +++ b/core/user.go @@ -70,8 +70,11 @@ type ( // Delete deletes a user from the datastore. Delete(context.Context, *User) error - // Count returns a count of active users. + // Count returns a count of human and machine users. Count(context.Context) (int64, error) + + // CountHuman returns a count of human users. + CountHuman(context.Context) (int64, error) } // UserService provides access to user account diff --git a/mock/mock_gen.go b/mock/mock_gen.go index 53ac2fd4..d1bec5b9 100644 --- a/mock/mock_gen.go +++ b/mock/mock_gen.go @@ -1880,6 +1880,21 @@ func (mr *MockUserStoreMockRecorder) Count(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockUserStore)(nil).Count), arg0) } +// CountHuman mocks base method +func (m *MockUserStore) CountHuman(arg0 context.Context) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountHuman", arg0) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountHuman indicates an expected call of CountHuman +func (mr *MockUserStoreMockRecorder) CountHuman(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountHuman", reflect.TypeOf((*MockUserStore)(nil).CountHuman), arg0) +} + // Create mocks base method func (m *MockUserStore) Create(arg0 context.Context, arg1 *core.User) error { m.ctrl.T.Helper() diff --git a/store/user/user.go b/store/user/user.go index dfa7d62c..ebe4f311 100644 --- a/store/user/user.go +++ b/store/user/user.go @@ -159,11 +159,31 @@ func (s *userStore) Count(ctx context.Context) (int64, error) { return out, err } +// Count returns a count of active human users. +func (s *userStore) CountHuman(ctx context.Context) (int64, error) { + var out int64 + err := s.db.View(func(queryer db.Queryer, binder db.Binder) error { + params := toParams(&core.User{Machine: false}) + stmt, args, err := binder.BindNamed(queryCountHuman, params) + if err != nil { + return err + } + return queryer.QueryRow(stmt, args...).Scan(&out) + }) + return out, err +} + const queryCount = ` SELECT COUNT(*) FROM users ` +const queryCountHuman = ` +SELECT COUNT(*) +FROM users +WHERE user_machine = :user_machine +` + const queryBase = ` SELECT user_id diff --git a/store/user/user_test.go b/store/user/user_test.go index 778067ab..898720a8 100644 --- a/store/user/user_test.go +++ b/store/user/user_test.go @@ -66,6 +66,14 @@ func testUserCount(users *userStore) func(t *testing.T) { if got, want := count, int64(1); got != want { t.Errorf("Want user table count %d, got %d", want, got) } + + count, err = users.CountHuman(noContext) + if err != nil { + t.Error(err) + } + if got, want := count, int64(1); got != want { + t.Errorf("Want user table count %d, got %d", want, got) + } } }