diff --git a/submitqueue/extension/storage/BUILD.bazel b/submitqueue/extension/storage/BUILD.bazel index 867eef0f..25ef54ca 100644 --- a/submitqueue/extension/storage/BUILD.bazel +++ b/submitqueue/extension/storage/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "storage", srcs = [ "batch_dependent_store.go", + "batch_state_membership_store.go", "batch_store.go", "build_store.go", "change_store.go", diff --git a/submitqueue/extension/storage/batch_state_membership_store.go b/submitqueue/extension/storage/batch_state_membership_store.go new file mode 100644 index 00000000..80bd2514 --- /dev/null +++ b/submitqueue/extension/storage/batch_state_membership_store.go @@ -0,0 +1,41 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +//go:generate mockgen -source=batch_state_membership_store.go -destination=mock/batch_state_membership_store_mock.go -package=mock + +import ( + "context" + + "github.com/uber/submitqueue/submitqueue/entity" +) + +// BatchStateMembershipStore records the app-maintained lookup from +// (queue, batch state) to batch IDs. The batch row remains authoritative: +// callers must resolve IDs through BatchStore and filter on the current +// persisted state. +type BatchStateMembershipStore interface { + // Add records that batchID belongs to (queue, state). Repeating the same + // Add is idempotent. + Add(ctx context.Context, queue string, state entity.BatchState, batchID string) error + + // Remove deletes a single membership row. Removing a missing row is + // idempotent and succeeds. + Remove(ctx context.Context, queue string, state entity.BatchState, batchID string) error + + // ListIDs returns every batch ID recorded for (queue, state). An empty slice + // means no membership rows exist for that key. + ListIDs(ctx context.Context, queue string, state entity.BatchState) ([]string, error) +} diff --git a/submitqueue/extension/storage/batch_store.go b/submitqueue/extension/storage/batch_store.go index 05e94fb8..f3c0a6fb 100644 --- a/submitqueue/extension/storage/batch_store.go +++ b/submitqueue/extension/storage/batch_store.go @@ -40,7 +40,4 @@ type BatchStore interface { // if the current persisted version matches oldVersion. If versions do not match, returns ErrVersionMismatch. // Version arithmetic is owned by the caller; the store performs a pure conditional write. UpdateScoreAndState(ctx context.Context, id string, oldVersion, newVersion int32, score float64, newState entity.BatchState) error - - // GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. - GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) ([]entity.Batch, error) } diff --git a/submitqueue/extension/storage/mock/BUILD.bazel b/submitqueue/extension/storage/mock/BUILD.bazel index 55c5d808..2c80434a 100644 --- a/submitqueue/extension/storage/mock/BUILD.bazel +++ b/submitqueue/extension/storage/mock/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "mock", srcs = [ "batch_dependent_store_mock.go", + "batch_state_membership_store_mock.go", "batch_store_mock.go", "build_store_mock.go", "change_store_mock.go", diff --git a/submitqueue/extension/storage/mock/batch_state_membership_store_mock.go b/submitqueue/extension/storage/mock/batch_state_membership_store_mock.go new file mode 100644 index 00000000..1501ea4e --- /dev/null +++ b/submitqueue/extension/storage/mock/batch_state_membership_store_mock.go @@ -0,0 +1,85 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: batch_state_membership_store.go +// +// Generated by this command: +// +// mockgen -source=batch_state_membership_store.go -destination=mock/batch_state_membership_store_mock.go -package=mock +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + entity "github.com/uber/submitqueue/submitqueue/entity" + gomock "go.uber.org/mock/gomock" +) + +// MockBatchStateMembershipStore is a mock of BatchStateMembershipStore interface. +type MockBatchStateMembershipStore struct { + ctrl *gomock.Controller + recorder *MockBatchStateMembershipStoreMockRecorder + isgomock struct{} +} + +// MockBatchStateMembershipStoreMockRecorder is the mock recorder for MockBatchStateMembershipStore. +type MockBatchStateMembershipStoreMockRecorder struct { + mock *MockBatchStateMembershipStore +} + +// NewMockBatchStateMembershipStore creates a new mock instance. +func NewMockBatchStateMembershipStore(ctrl *gomock.Controller) *MockBatchStateMembershipStore { + mock := &MockBatchStateMembershipStore{ctrl: ctrl} + mock.recorder = &MockBatchStateMembershipStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatchStateMembershipStore) EXPECT() *MockBatchStateMembershipStoreMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockBatchStateMembershipStore) Add(ctx context.Context, queue string, state entity.BatchState, batchID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", ctx, queue, state, batchID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockBatchStateMembershipStoreMockRecorder) Add(ctx, queue, state, batchID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockBatchStateMembershipStore)(nil).Add), ctx, queue, state, batchID) +} + +// ListIDs mocks base method. +func (m *MockBatchStateMembershipStore) ListIDs(ctx context.Context, queue string, state entity.BatchState) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListIDs", ctx, queue, state) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListIDs indicates an expected call of ListIDs. +func (mr *MockBatchStateMembershipStoreMockRecorder) ListIDs(ctx, queue, state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListIDs", reflect.TypeOf((*MockBatchStateMembershipStore)(nil).ListIDs), ctx, queue, state) +} + +// Remove mocks base method. +func (m *MockBatchStateMembershipStore) Remove(ctx context.Context, queue string, state entity.BatchState, batchID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Remove", ctx, queue, state, batchID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Remove indicates an expected call of Remove. +func (mr *MockBatchStateMembershipStoreMockRecorder) Remove(ctx, queue, state, batchID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockBatchStateMembershipStore)(nil).Remove), ctx, queue, state, batchID) +} diff --git a/submitqueue/extension/storage/mock/batch_store_mock.go b/submitqueue/extension/storage/mock/batch_store_mock.go index 48b6bdaf..429a6d2a 100644 --- a/submitqueue/extension/storage/mock/batch_store_mock.go +++ b/submitqueue/extension/storage/mock/batch_store_mock.go @@ -70,21 +70,6 @@ func (mr *MockBatchStoreMockRecorder) Get(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBatchStore)(nil).Get), ctx, id) } -// GetByQueueAndStates mocks base method. -func (m *MockBatchStore) GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) ([]entity.Batch, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetByQueueAndStates", ctx, queue, states) - ret0, _ := ret[0].([]entity.Batch) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetByQueueAndStates indicates an expected call of GetByQueueAndStates. -func (mr *MockBatchStoreMockRecorder) GetByQueueAndStates(ctx, queue, states any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByQueueAndStates", reflect.TypeOf((*MockBatchStore)(nil).GetByQueueAndStates), ctx, queue, states) -} - // UpdateScoreAndState mocks base method. func (m *MockBatchStore) UpdateScoreAndState(ctx context.Context, id string, oldVersion, newVersion int32, score float64, newState entity.BatchState) error { m.ctrl.T.Helper() diff --git a/submitqueue/extension/storage/mock/storage_mock.go b/submitqueue/extension/storage/mock/storage_mock.go index 4133bc2a..b19e355e 100644 --- a/submitqueue/extension/storage/mock/storage_mock.go +++ b/submitqueue/extension/storage/mock/storage_mock.go @@ -68,6 +68,20 @@ func (mr *MockStorageMockRecorder) GetBatchDependentStore() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatchDependentStore", reflect.TypeOf((*MockStorage)(nil).GetBatchDependentStore)) } +// GetBatchStateMembershipStore mocks base method. +func (m *MockStorage) GetBatchStateMembershipStore() storage.BatchStateMembershipStore { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBatchStateMembershipStore") + ret0, _ := ret[0].(storage.BatchStateMembershipStore) + return ret0 +} + +// GetBatchStateMembershipStore indicates an expected call of GetBatchStateMembershipStore. +func (mr *MockStorageMockRecorder) GetBatchStateMembershipStore() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBatchStateMembershipStore", reflect.TypeOf((*MockStorage)(nil).GetBatchStateMembershipStore)) +} + // GetBatchStore mocks base method. func (m *MockStorage) GetBatchStore() storage.BatchStore { m.ctrl.T.Helper() diff --git a/submitqueue/extension/storage/mysql/BUILD.bazel b/submitqueue/extension/storage/mysql/BUILD.bazel index 25fefdc5..297f1c95 100644 --- a/submitqueue/extension/storage/mysql/BUILD.bazel +++ b/submitqueue/extension/storage/mysql/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "mysql", srcs = [ "batch_dependent_store.go", + "batch_state_membership_store.go", "batch_store.go", "build_store.go", "change_store.go", diff --git a/submitqueue/extension/storage/mysql/batch_state_membership_store.go b/submitqueue/extension/storage/mysql/batch_state_membership_store.go new file mode 100644 index 00000000..15f8e78c --- /dev/null +++ b/submitqueue/extension/storage/mysql/batch_state_membership_store.go @@ -0,0 +1,87 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "context" + "database/sql" + "fmt" + + "github.com/uber-go/tally" + + "github.com/uber/submitqueue/core/metrics" + "github.com/uber/submitqueue/submitqueue/entity" + "github.com/uber/submitqueue/submitqueue/extension/storage" +) + +type batchStateMembershipStore struct { + db *sql.DB + scope tally.Scope +} + +// NewBatchStateMembershipStore creates a new MySQL-backed BatchStateMembershipStore. +func NewBatchStateMembershipStore(db *sql.DB, scope tally.Scope) storage.BatchStateMembershipStore { + return &batchStateMembershipStore{db: db, scope: scope} +} + +// Add records a batch's state membership. Duplicate membership is a retry-safe no-op. +func (s *batchStateMembershipStore) Add(ctx context.Context, queue string, state entity.BatchState, batchID string) (retErr error) { + op := metrics.Begin(s.scope, "add") + defer func() { op.Complete(retErr) }() + + const query = "INSERT INTO batch_state_membership (queue, state, batch_id) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE batch_id = batch_id" + if _, err := s.db.ExecContext(ctx, query, queue, state, batchID); err != nil { + return fmt.Errorf("failed to add batch state membership queue=%q state=%q batch_id=%q: %w", queue, state, batchID, err) + } + return nil +} + +// Remove deletes a batch's state membership. Missing rows are treated as already removed. +func (s *batchStateMembershipStore) Remove(ctx context.Context, queue string, state entity.BatchState, batchID string) (retErr error) { + op := metrics.Begin(s.scope, "remove") + defer func() { op.Complete(retErr) }() + + const query = "DELETE FROM batch_state_membership WHERE queue = ? AND state = ? AND batch_id = ?" + if _, err := s.db.ExecContext(ctx, query, queue, state, batchID); err != nil { + return fmt.Errorf("failed to remove batch state membership queue=%q state=%q batch_id=%q: %w", queue, state, batchID, err) + } + return nil +} + +// ListIDs returns batch IDs recorded for a queue and state. +func (s *batchStateMembershipStore) ListIDs(ctx context.Context, queue string, state entity.BatchState) (ret []string, retErr error) { + op := metrics.Begin(s.scope, "list_ids") + defer func() { op.Complete(retErr) }() + + const query = "SELECT batch_id FROM batch_state_membership WHERE queue = ? AND state = ?" + rows, err := s.db.QueryContext(ctx, query, queue, state) + if err != nil { + return nil, fmt.Errorf("failed to list batch state memberships queue=%q state=%q: %w", queue, state, err) + } + defer rows.Close() + + var ids []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("failed to scan batch state membership queue=%q state=%q: %w", queue, state, err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to iterate batch state memberships queue=%q state=%q: %w", queue, state, err) + } + return ids, nil +} diff --git a/submitqueue/extension/storage/mysql/batch_store.go b/submitqueue/extension/storage/mysql/batch_store.go index e6999aac..a7eef395 100644 --- a/submitqueue/extension/storage/mysql/batch_store.go +++ b/submitqueue/extension/storage/mysql/batch_store.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "fmt" - "strings" "github.com/go-sql-driver/mysql" "github.com/uber-go/tally" @@ -173,53 +172,3 @@ func (s *batchStore) UpdateScoreAndState(ctx context.Context, id string, oldVers return nil } - -// GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. -func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) (ret []entity.Batch, retErr error) { - op := metrics.Begin(s.scope, "get_by_queue_and_states") - defer func() { op.Complete(retErr) }() - - if len(states) == 0 { - return nil, nil - } - - query := "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")" - - args := make([]any, 1+len(states)) - args[0] = queue - for i, state := range states { - args[i+1] = state - } - - rows, err := s.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, fmt.Errorf("failed to query batches by queue=%q states=%v from the database: %w", queue, states, err) - } - defer rows.Close() - - var results []entity.Batch - for rows.Next() { - var batch entity.Batch - var containsJSON []byte - var dependenciesJSON []byte - - if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version); err != nil { - return nil, fmt.Errorf("failed to scan batch entity by queue=%q states=%v from the database: %w", queue, states, err) - } - - if err := json.Unmarshal(containsJSON, &batch.Contains); err != nil { - return nil, fmt.Errorf("failed to unmarshal contains for batch entity id=%s from the database: %w", batch.ID, err) - } - - if err := json.Unmarshal(dependenciesJSON, &batch.Dependencies); err != nil { - return nil, fmt.Errorf("failed to unmarshal dependencies for batch entity id=%s from the database: %w", batch.ID, err) - } - - results = append(results, batch) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("failed to iterate batches by queue=%q states=%v from the database: %w", queue, states, err) - } - - return results, nil -} diff --git a/submitqueue/extension/storage/mysql/schema/README.md b/submitqueue/extension/storage/mysql/schema/README.md index 9e8c36bf..fdcb5606 100644 --- a/submitqueue/extension/storage/mysql/schema/README.md +++ b/submitqueue/extension/storage/mysql/schema/README.md @@ -2,23 +2,26 @@ ## batch table -### Secondary index: `idx_queue_state (queue, state)` +The `batch` table is reachable only by its primary key (`id`). It carries no secondary index over mutable columns. Queue/state access patterns are expressed through an app-maintained companion table (see `batch_state_membership` below), and callers resolve batch IDs back through the authoritative `batch` row before making state decisions. -The `batch` table has a composite secondary index on `(queue, state)`. This index supports the `GetByQueueAndStates` query, which retrieves batches filtered by queue and one or more states. Without this index, the query would require a full table scan. +## batch_state_membership table -#### Trade-offs +`batch_state_membership` is the app-maintained lookup that answers "which batch IDs are recorded for this queue and state?" The table's primary key is `(queue, state, batch_id)`, so reads by queue/state are primary-key-prefix scans. The same shape ports to a key-value/document store that supports ranged scans over a partition/sort key, and it avoids a server-maintained secondary index over mutable `batch.state`. -- **Write overhead**: Every `INSERT` and `UPDATE` to the `batch` table must also update the secondary index, adding latency to write operations. -- **Storage cost**: The index consumes additional disk space proportional to the number of rows in the table. -- **Lock contention**: Under high write concurrency, index maintenance can increase lock contention on the affected index pages. +The table is not authoritative. The orchestrator resolves every listed `batch_id` with `BatchStore.Get` and filters by the current persisted `Batch.State`. This keeps storage generic: storage owns primitive membership records, while the orchestrator owns app concepts such as "active", "terminal", and "eligible for conflict analysis". -#### Future: Prune job +### Maintenance -As the `batch` table grows, the secondary index will grow with it, increasing storage costs and degrading write performance. To mitigate this, a prune job should be introduced to periodically delete batches in terminal states (`succeeded`, `failed`, `cancelled`) that are older than a configurable retention period. This keeps the table and its indexes bounded in size, ensuring consistent query and write performance over time. +The orchestrator writes the target non-terminal membership row before creating a batch or before CASing a batch into a new non-terminal state. After a successful CAS, it best-effort removes the previous non-terminal membership row. Terminal transitions do not add a target membership row; after the CAS succeeds, the previous non-terminal membership row is best-effort removed. + +Because membership writes and batch writes are independent, stale rows are expected in failure windows. Readers skip missing batch rows and filter stale state rows against the authoritative batch row. A terminal stale row may be removed on read because batch IDs are never reused. + +### Future: prune / reconcile job + +A reconcile job can periodically sweep dangling rows whose batch never landed and stale rows whose authoritative batch state no longer matches the membership state. This keeps the table bounded independently of read traffic. ## change table ### Composite primary key: `(queue, uri, request_id)` The `change` table records per-URI claims by in-flight requests. `request_id` is part of the primary key so that concurrent claims on the same URI by different requests coexist as distinct rows — a same-request retry collides on the PK and is a no-op (`INSERT IGNORE`), while a different-request claim is a new row that `GetByURI` surfaces for overlap detection. `queue` leads the key so queue-scoped lookups are primary-key-prefix scans and the table is shardable by queue. - diff --git a/submitqueue/extension/storage/mysql/schema/batch.sql b/submitqueue/extension/storage/mysql/schema/batch.sql index 8e7deda0..079cdecd 100644 --- a/submitqueue/extension/storage/mysql/schema/batch.sql +++ b/submitqueue/extension/storage/mysql/schema/batch.sql @@ -4,8 +4,7 @@ CREATE TABLE IF NOT EXISTS batch ( contains JSON NOT NULL, dependencies JSON NOT NULL, score DOUBLE NOT NULL, - state VARCHAR(255) NOT NUll, + state VARCHAR(255) NOT NULL, version INT NOT NULL, - PRIMARY KEY (id), - INDEX idx_queue_state (queue, state) + PRIMARY KEY (id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/submitqueue/extension/storage/mysql/schema/batch_state_membership.sql b/submitqueue/extension/storage/mysql/schema/batch_state_membership.sql new file mode 100644 index 00000000..f6b35f14 --- /dev/null +++ b/submitqueue/extension/storage/mysql/schema/batch_state_membership.sql @@ -0,0 +1,9 @@ +-- batch_state_membership is the app-maintained lookup for +-- queue,state -> batch_ids. Batch rows are authoritative: callers resolve IDs +-- through the batch table and filter by the current persisted state. +CREATE TABLE IF NOT EXISTS batch_state_membership ( + queue VARCHAR(255) NOT NULL, + state VARCHAR(255) NOT NULL, + batch_id VARCHAR(255) NOT NULL, + PRIMARY KEY (queue, state, batch_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; diff --git a/submitqueue/extension/storage/mysql/storage.go b/submitqueue/extension/storage/mysql/storage.go index 4ba0cf41..f38aa862 100644 --- a/submitqueue/extension/storage/mysql/storage.go +++ b/submitqueue/extension/storage/mysql/storage.go @@ -24,27 +24,29 @@ import ( ) type mysqlStorage struct { - db *sql.DB - requestStore storage.RequestStore - changeStore storage.ChangeStore - batchStore storage.BatchStore - batchDependentStore storage.BatchDependentStore - buildStore storage.BuildStore - speculationTreeStore storage.SpeculationTreeStore - requestLogStore storage.RequestLogStore + db *sql.DB + requestStore storage.RequestStore + changeStore storage.ChangeStore + batchStore storage.BatchStore + batchStateMembershipStore storage.BatchStateMembershipStore + batchDependentStore storage.BatchDependentStore + buildStore storage.BuildStore + speculationTreeStore storage.SpeculationTreeStore + requestLogStore storage.RequestLogStore } // NewStorage creates a new MySQL storage. func NewStorage(db *sql.DB, scope tally.Scope) (storage.Storage, error) { return &mysqlStorage{ - db: db, - requestStore: NewRequestStore(db, scope.SubScope("request_store")), - changeStore: NewChangeStore(db, scope.SubScope("change_store")), - batchStore: NewBatchStore(db, scope.SubScope("batch_store")), - batchDependentStore: NewBatchDependentStore(db, scope.SubScope("batch_dependent_store")), - buildStore: NewBuildStore(db, scope.SubScope("build_store")), - speculationTreeStore: NewSpeculationTreeStore(db, scope.SubScope("speculation_tree_store")), - requestLogStore: NewRequestLogStore(db, scope.SubScope("request_log_store")), + db: db, + requestStore: NewRequestStore(db, scope.SubScope("request_store")), + changeStore: NewChangeStore(db, scope.SubScope("change_store")), + batchStore: NewBatchStore(db, scope.SubScope("batch_store")), + batchStateMembershipStore: NewBatchStateMembershipStore(db, scope.SubScope("batch_state_membership_store")), + batchDependentStore: NewBatchDependentStore(db, scope.SubScope("batch_dependent_store")), + buildStore: NewBuildStore(db, scope.SubScope("build_store")), + speculationTreeStore: NewSpeculationTreeStore(db, scope.SubScope("speculation_tree_store")), + requestLogStore: NewRequestLogStore(db, scope.SubScope("request_log_store")), }, nil } @@ -63,6 +65,11 @@ func (f *mysqlStorage) GetBatchStore() storage.BatchStore { return f.batchStore } +// GetBatchStateMembershipStore returns the MySQL-backed BatchStateMembershipStore. +func (f *mysqlStorage) GetBatchStateMembershipStore() storage.BatchStateMembershipStore { + return f.batchStateMembershipStore +} + // GetBatchDependentStore returns the MySQL-backed BatchDependentStore. func (f *mysqlStorage) GetBatchDependentStore() storage.BatchDependentStore { return f.batchDependentStore diff --git a/submitqueue/extension/storage/storage.go b/submitqueue/extension/storage/storage.go index a02bef73..c6841a1a 100644 --- a/submitqueue/extension/storage/storage.go +++ b/submitqueue/extension/storage/storage.go @@ -53,6 +53,9 @@ type Storage interface { // GetBatchStore returns the BatchStore instance. GetBatchStore() BatchStore + // GetBatchStateMembershipStore returns the BatchStateMembershipStore instance. + GetBatchStateMembershipStore() BatchStateMembershipStore + // GetBatchDependentStore returns the BatchDependentStore instance. GetBatchDependentStore() BatchDependentStore diff --git a/submitqueue/orchestrator/controller/batch/BUILD.bazel b/submitqueue/orchestrator/controller/batch/BUILD.bazel index 72a89398..89c1e84f 100644 --- a/submitqueue/orchestrator/controller/batch/BUILD.bazel +++ b/submitqueue/orchestrator/controller/batch/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//submitqueue/entity", "//submitqueue/extension/conflict", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/batch/batch.go b/submitqueue/orchestrator/controller/batch/batch.go index fef116fd..178ef38c 100644 --- a/submitqueue/orchestrator/controller/batch/batch.go +++ b/submitqueue/orchestrator/controller/batch/batch.go @@ -28,6 +28,7 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/conflict" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" "go.uber.org/zap" ) @@ -134,14 +135,10 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r Version: 1, } - // Get active batches for this queue and ask the conflict analyzer which - // of them the new batch must serialize behind. The dependency set drives - // the speculation graph downstream. - activeBatches, err := c.store.GetBatchStore().GetByQueueAndStates(ctx, request.Queue, []entity.BatchState{ - entity.BatchStateCreated, - entity.BatchStateSpeculating, - entity.BatchStateMerging, - }) + // Ask the conflict analyzer which active batches the new batch must serialize + // behind. Membership rows are hints; batchstate.List resolves authoritative + // batch rows and returns only the current states requested here. + activeBatches, err := batchstate.List(ctx, c.store, request.Queue, batchstate.ConflictStates...) if err != nil { metrics.NamedCounter(c.metricsScope, opName, "batch_store_errors", 1) return fmt.Errorf("failed to get active batches for queue=%s: %w", request.Queue, err) @@ -279,7 +276,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r // Persist batch to storage. // This is the final operation that concludes the batch creation process. If it fails, BatchDependents will be pointing to a batch id that does not exist. // We do not reuse batch ids, a retry of this operation will create a new batch with a new ID. The downstream logic that operates on BatchDependent should be able to handle stale entries. - if err := c.store.GetBatchStore().Create(ctx, batch); err != nil { + if err := batchstate.Create(ctx, c.store, batch); err != nil { metrics.NamedCounter(c.metricsScope, opName, "batch_store_errors", 1) return fmt.Errorf("failed to create batch in batch store: %w", err) } diff --git a/submitqueue/orchestrator/controller/batch/batch_test.go b/submitqueue/orchestrator/controller/batch/batch_test.go index 792c9773..c8700acf 100644 --- a/submitqueue/orchestrator/controller/batch/batch_test.go +++ b/submitqueue/orchestrator/controller/batch/batch_test.go @@ -72,6 +72,32 @@ func testRequest() entity.Request { } } +func expectConflictMembership( + membershipStore *storagemock.MockBatchStateMembershipStore, + batchStore *storagemock.MockBatchStore, + queue string, + batches []entity.Batch, +) { + states := []entity.BatchState{ + entity.BatchStateCreated, + entity.BatchStateSpeculating, + entity.BatchStateMerging, + } + byState := make(map[entity.BatchState][]string, len(states)) + byID := make(map[string]entity.Batch, len(batches)) + for _, b := range batches { + byState[b.State] = append(byState[b.State], b.ID) + byID[b.ID] = b + } + for _, state := range states { + ids := byState[state] + membershipStore.EXPECT().ListIDs(gomock.Any(), queue, state).Return(ids, nil) + for _, id := range ids { + batchStore.EXPECT().Get(gomock.Any(), id).Return(byID[id], nil) + } + } +} + // newTestController creates a controller with test dependencies. // If mockStorage is nil, a default MockStorage with an empty batch store is created. // If analyzer is nil, the "all" conflict analyzer is used (every active batch becomes a dependency). @@ -81,8 +107,12 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, cnt *countermock.M if mockStorage == nil { mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockBatchStore.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + for _, state := range []entity.BatchState{entity.BatchStateCreated, entity.BatchStateSpeculating, entity.BatchStateMerging} { + mockMembershipStore.EXPECT().ListIDs(gomock.Any(), gomock.Any(), state).Return(nil, nil).AnyTimes() + } + mockMembershipStore.EXPECT().Add(gomock.Any(), gomock.Any(), entity.BatchStateCreated, gomock.Any()).Return(nil).AnyTimes() mockReqStore := storagemock.NewMockRequestStore(ctrl) req := testRequest() @@ -94,6 +124,7 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, cnt *countermock.M mockStorage = storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() } @@ -212,15 +243,24 @@ func TestController_Process_WithDependencies(t *testing.T) { Version: 1, } - // Set up storage with active batches to become dependencies. + // Set up storage with active batches to become dependencies. The controller + // queries only Created/Speculating/Merging membership states. The Scored and + // Cancelling batches below must be excluded — note no BatchDependent + // expectations are registered for them, so the default all.New() analyzer + // would fail the test on an unexpected mock call if the filter let them + // through. activeBatches := []entity.Batch{ {ID: "test-queue/batch/1", Queue: "test-queue", State: entity.BatchStateCreated, Version: 1}, {ID: "test-queue/batch/2", Queue: "test-queue", State: entity.BatchStateSpeculating, Version: 2}, + {ID: "test-queue/batch/3", Queue: "test-queue", State: entity.BatchStateScored, Version: 1}, + {ID: "test-queue/batch/4", Queue: "test-queue", State: entity.BatchStateCancelling, Version: 1}, } mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(activeBatches, nil) mockBatchStore.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", activeBatches) + mockMembershipStore.EXPECT().Add(gomock.Any(), "test-queue", entity.BatchStateCreated, gomock.Any()).Return(nil) mockBatchDependentStore := storagemock.NewMockBatchDependentStore(ctrl) // batch/1 has no existing dependents. @@ -245,6 +285,7 @@ func TestController_Process_WithDependencies(t *testing.T) { mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() @@ -271,8 +312,10 @@ func TestController_Process_AnalyzerSelectsSubset(t *testing.T) { } mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(activeBatches, nil) mockBatchStore.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", activeBatches) + mockMembershipStore.EXPECT().Add(gomock.Any(), "test-queue", entity.BatchStateCreated, gomock.Any()).Return(nil) mockBatchDependentStore := storagemock.NewMockBatchDependentStore(ctrl) // Only batch/2 is selected by the analyzer, so only it gets a reverse-index update. @@ -289,6 +332,7 @@ func TestController_Process_AnalyzerSelectsSubset(t *testing.T) { mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() @@ -317,13 +361,15 @@ func TestController_Process_AnalyzerFailure(t *testing.T) { request := testRequest() mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(nil, nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", nil) mockReqStore := storagemock.NewMockRequestStore(ctrl) mockReqStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() analyzer := conflictmock.NewMockAnalyzer(ctrl) @@ -413,7 +459,8 @@ func TestController_Process_CASLostToCancel(t *testing.T) { request := testRequest() mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(nil, nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", nil) // Create must NOT be called — gomock fails if it is. mockBatchDependentStore := storagemock.NewMockBatchDependentStore(ctrl) @@ -429,6 +476,7 @@ func TestController_Process_CASLostToCancel(t *testing.T) { mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() @@ -466,7 +514,8 @@ func TestController_Process_CASUnexpectedErrorPropagates(t *testing.T) { request := testRequest() mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(nil, nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", nil) // Create must NOT be called — gomock fails if it is. mockBatchDependentStore := storagemock.NewMockBatchDependentStore(ctrl) @@ -481,6 +530,7 @@ func TestController_Process_CASUnexpectedErrorPropagates(t *testing.T) { mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() @@ -512,8 +562,10 @@ func TestController_Process_RecoveryAfterPriorCAS(t *testing.T) { request.Version = 2 // prior attempt bumped from 1 → 2 mockBatchStore := storagemock.NewMockBatchStore(ctrl) - mockBatchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "test-queue", gomock.Any()).Return(nil, nil) mockBatchStore.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) + mockMembershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectConflictMembership(mockMembershipStore, mockBatchStore, "test-queue", nil) + mockMembershipStore.EXPECT().Add(gomock.Any(), "test-queue", entity.BatchStateCreated, gomock.Any()).Return(nil) mockBatchDependentStore := storagemock.NewMockBatchDependentStore(ctrl) mockBatchDependentStore.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil) @@ -526,6 +578,7 @@ func TestController_Process_RecoveryAfterPriorCAS(t *testing.T) { mockStorage := storagemock.NewMockStorage(ctrl) mockStorage.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + mockStorage.EXPECT().GetBatchStateMembershipStore().Return(mockMembershipStore).AnyTimes() mockStorage.EXPECT().GetBatchDependentStore().Return(mockBatchDependentStore).AnyTimes() mockStorage.EXPECT().GetRequestStore().Return(mockReqStore).AnyTimes() diff --git a/submitqueue/orchestrator/controller/batchstate/BUILD.bazel b/submitqueue/orchestrator/controller/batchstate/BUILD.bazel new file mode 100644 index 00000000..ff0aca84 --- /dev/null +++ b/submitqueue/orchestrator/controller/batchstate/BUILD.bazel @@ -0,0 +1,12 @@ +load("@rules_go//go:def.bzl", "go_library") + +go_library( + name = "batchstate", + srcs = ["batchstate.go"], + importpath = "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate", + visibility = ["//visibility:public"], + deps = [ + "//submitqueue/entity", + "//submitqueue/extension/storage", + ], +) diff --git a/submitqueue/orchestrator/controller/batchstate/batchstate.go b/submitqueue/orchestrator/controller/batchstate/batchstate.go new file mode 100644 index 00000000..b18e11b9 --- /dev/null +++ b/submitqueue/orchestrator/controller/batchstate/batchstate.go @@ -0,0 +1,144 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package batchstate keeps batch-state membership maintenance in the +// orchestrator layer. Storage owns primitive records; this package owns the app +// semantics for creating batches, transitioning states, and resolving +// queue/state membership into authoritative Batch entities. +package batchstate + +import ( + "context" + "fmt" + + "github.com/uber/submitqueue/submitqueue/entity" + "github.com/uber/submitqueue/submitqueue/extension/storage" +) + +// NonTerminalStates are the batch states that should remain discoverable via +// membership lookups. +var NonTerminalStates = []entity.BatchState{ + entity.BatchStateCreated, + entity.BatchStateScored, + entity.BatchStateSpeculating, + entity.BatchStateMerging, + entity.BatchStateCancelling, +} + +// ConflictStates are the active states the batch controller considers when +// building conflict dependencies for a new batch. +var ConflictStates = []entity.BatchState{ + entity.BatchStateCreated, + entity.BatchStateSpeculating, + entity.BatchStateMerging, +} + +// Create records the initial non-terminal membership before creating the batch +// row, so a successfully persisted batch is not hidden from queue/state reads. +func Create(ctx context.Context, store storage.Storage, batch entity.Batch) error { + if !batch.State.IsTerminal() { + if err := store.GetBatchStateMembershipStore().Add(ctx, batch.Queue, batch.State, batch.ID); err != nil { + return fmt.Errorf("failed to add initial batch state membership for batch %s: %w", batch.ID, err) + } + } + if err := store.GetBatchStore().Create(ctx, batch); err != nil { + return err + } + return nil +} + +// UpdateState transitions a batch and maintains queue/state membership in the +// safe direction: add the target non-terminal state before the CAS, then +// best-effort remove the previous non-terminal state after the CAS succeeds. +func UpdateState(ctx context.Context, store storage.Storage, batch entity.Batch, newVersion int32, newState entity.BatchState) error { + if !newState.IsTerminal() { + if err := store.GetBatchStateMembershipStore().Add(ctx, batch.Queue, newState, batch.ID); err != nil { + return fmt.Errorf("failed to add batch state membership for batch %s state %s: %w", batch.ID, newState, err) + } + } + if err := store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, newState); err != nil { + return err + } + removePrevious(ctx, store, batch, newState) + return nil +} + +// UpdateScoreAndState is the score-writing variant of UpdateState. +func UpdateScoreAndState(ctx context.Context, store storage.Storage, batch entity.Batch, newVersion int32, score float64, newState entity.BatchState) error { + if !newState.IsTerminal() { + if err := store.GetBatchStateMembershipStore().Add(ctx, batch.Queue, newState, batch.ID); err != nil { + return fmt.Errorf("failed to add batch state membership for batch %s state %s: %w", batch.ID, newState, err) + } + } + if err := store.GetBatchStore().UpdateScoreAndState(ctx, batch.ID, batch.Version, newVersion, score, newState); err != nil { + return err + } + removePrevious(ctx, store, batch, newState) + return nil +} + +// List returns batches in queue whose current authoritative state is one of +// states. Membership rows are only hints: missing batch rows are skipped, stale +// rows are filtered, and terminal stale rows are best-effort removed. +func List(ctx context.Context, store storage.Storage, queue string, states ...entity.BatchState) ([]entity.Batch, error) { + if len(states) == 0 { + return nil, nil + } + + wanted := make(map[entity.BatchState]struct{}, len(states)) + for _, state := range states { + wanted[state] = struct{}{} + } + + seen := make(map[string]struct{}) + results := make([]entity.Batch, 0) + for _, state := range states { + ids, err := store.GetBatchStateMembershipStore().ListIDs(ctx, queue, state) + if err != nil { + return nil, fmt.Errorf("failed to list batch IDs for queue=%s state=%s: %w", queue, state, err) + } + for _, id := range ids { + if _, ok := seen[id]; ok { + continue + } + batch, err := store.GetBatchStore().Get(ctx, id) + if err != nil { + if storage.IsNotFound(err) { + continue + } + return nil, fmt.Errorf("failed to get batch id=%s from queue=%s state=%s membership: %w", id, queue, state, err) + } + if batch.Queue != queue { + continue + } + if batch.State.IsTerminal() { + _ = store.GetBatchStateMembershipStore().Remove(ctx, queue, state, id) + continue + } + if _, ok := wanted[batch.State]; !ok { + continue + } + seen[id] = struct{}{} + results = append(results, batch) + } + } + return results, nil +} + +func removePrevious(ctx context.Context, store storage.Storage, batch entity.Batch, newState entity.BatchState) { + if batch.State.IsTerminal() || batch.State == newState { + return + } + _ = store.GetBatchStateMembershipStore().Remove(ctx, batch.Queue, batch.State, batch.ID) +} diff --git a/submitqueue/orchestrator/controller/cancel/BUILD.bazel b/submitqueue/orchestrator/controller/cancel/BUILD.bazel index 466799f2..e2c29f2d 100644 --- a/submitqueue/orchestrator/controller/cancel/BUILD.bazel +++ b/submitqueue/orchestrator/controller/cancel/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//submitqueue/core/topickey", "//submitqueue/entity", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/cancel/cancel.go b/submitqueue/orchestrator/controller/cancel/cancel.go index 96a214ea..37f65490 100644 --- a/submitqueue/orchestrator/controller/cancel/cancel.go +++ b/submitqueue/orchestrator/controller/cancel/cancel.go @@ -65,6 +65,7 @@ import ( "github.com/uber/submitqueue/submitqueue/core/topickey" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" "go.uber.org/zap" ) @@ -184,21 +185,11 @@ func (c *Controller) markCancelling(ctx context.Context, request entity.Request) // findActiveBatch scans all active batches in the request's queue for one whose // Contains list includes the request. Returns (batch, true, nil) on a hit, // (zero, false, nil) when the request is not yet batched, and any storage -// error otherwise. -// -// BatchStateCancelling is included in the active-state list so an idempotent -// redelivery of the cancel message (the prior pass wrote the intent but the -// speculate hand-off publish failed) still resolves the batch and re-attempts -// the publish. +// error otherwise. NonTerminalStates includes Cancelling batches, so redelivery +// of a cancel whose speculate hand-off publish failed still resolves and retries. func (c *Controller) findActiveBatch(ctx context.Context, request entity.Request) (entity.Batch, bool, error) { // TODO: Scans all the batches in flight - make it more efficient? - active, err := c.store.GetBatchStore().GetByQueueAndStates(ctx, request.Queue, []entity.BatchState{ - entity.BatchStateCreated, - entity.BatchStateScored, - entity.BatchStateSpeculating, - entity.BatchStateMerging, - entity.BatchStateCancelling, - }) + active, err := batchstate.List(ctx, c.store, request.Queue, batchstate.NonTerminalStates...) if err != nil { c.metricsScope.Counter("batch_store_errors").Inc(1) return entity.Batch{}, false, fmt.Errorf("failed to get active batches for queue=%s: %w", request.Queue, err) @@ -271,7 +262,7 @@ func (c *Controller) cancelBatch(ctx context.Context, batch entity.Batch) error if batch.State != entity.BatchStateCancelling { newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, entity.BatchStateCancelling); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, entity.BatchStateCancelling); err != nil { c.metricsScope.Counter("batch_update_errors").Inc(1) // storage.ErrVersionMismatch here means the batch advanced concurrently // (e.g. speculate / merge progressed). Returned as-is for the base diff --git a/submitqueue/orchestrator/controller/cancel/cancel_test.go b/submitqueue/orchestrator/controller/cancel/cancel_test.go index a97079fb..d69d744d 100644 --- a/submitqueue/orchestrator/controller/cancel/cancel_test.go +++ b/submitqueue/orchestrator/controller/cancel/cancel_test.go @@ -71,6 +71,34 @@ func newDelivery(t *testing.T, ctrl *gomock.Controller, payload []byte, partitio return d } +func expectNonTerminalMembership( + membershipStore *storagemock.MockBatchStateMembershipStore, + batchStore *storagemock.MockBatchStore, + queue string, + batches []entity.Batch, +) { + states := []entity.BatchState{ + entity.BatchStateCreated, + entity.BatchStateScored, + entity.BatchStateSpeculating, + entity.BatchStateMerging, + entity.BatchStateCancelling, + } + byState := make(map[entity.BatchState][]string, len(states)) + byID := make(map[string]entity.Batch, len(batches)) + for _, b := range batches { + byState[b.State] = append(byState[b.State], b.ID) + byID[b.ID] = b + } + for _, state := range states { + ids := byState[state] + membershipStore.EXPECT().ListIDs(gomock.Any(), queue, state).Return(ids, nil) + for _, id := range ids { + batchStore.EXPECT().Get(gomock.Any(), id).Return(byID[id], nil) + } + } +} + func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) registry, pub := newRegistry(t, ctrl) @@ -144,11 +172,13 @@ func TestProcess_CancelsUnbatchedRequest(t *testing.T) { ) batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return(nil, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", nil) store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() controller := newController(t, store, registry) err := controller.Process(context.Background(), newDelivery(t, ctrl, cancelPayload(t, "q/1", "user changed mind"), "q/1")) @@ -175,11 +205,13 @@ func TestProcess_AlreadyCancelling_SkipsMarkCancelling(t *testing.T) { reqStore.EXPECT().UpdateState(gomock.Any(), "q/1", int32(3), int32(4), entity.RequestStateCancelled).Return(nil) batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return(nil, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", nil) store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() controller := newController(t, store, registry) err := controller.Process(context.Background(), newDelivery(t, ctrl, cancelPayload(t, "q/1", ""), "q/1")) @@ -228,11 +260,13 @@ func TestProcess_UnbatchedVersionMismatch_Retryable(t *testing.T) { ) batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return(nil, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", nil) store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() controller := newController(t, store, registry) err := controller.Process(context.Background(), newDelivery(t, ctrl, cancelPayload(t, "q/1", ""), "q/1")) @@ -276,13 +310,17 @@ func TestProcess_BatchPath_HandsOffToSpeculate(t *testing.T) { reqStore.EXPECT().UpdateState(gomock.Any(), "q/1", int32(2), int32(3), entity.RequestStateCancelling).Return(nil) batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return([]entity.Batch{batch}, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", []entity.Batch{batch}) + membershipStore.EXPECT().Add(gomock.Any(), "q", entity.BatchStateCancelling, batch.ID).Return(nil) // Single batch CAS: intent only. No terminal CAS. batchStore.EXPECT().UpdateState(gomock.Any(), batch.ID, int32(3), int32(4), entity.BatchStateCancelling).Return(nil) + membershipStore.EXPECT().Remove(gomock.Any(), "q", entity.BatchStateSpeculating, batch.ID).Return(nil) store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() // BatchDependentStore and BuildStore must NOT be touched — speculate owns those now. controller := newController(t, store, registry) @@ -325,12 +363,14 @@ func TestProcess_BatchAlreadyCancelling_RepublishesToSpeculate(t *testing.T) { // No request UpdateState — already in Cancelling. batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return([]entity.Batch{batch}, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", []entity.Batch{batch}) // No batch UpdateState — already in Cancelling. store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() controller := newController(t, store, registry) err := controller.Process(context.Background(), newDelivery(t, ctrl, cancelPayload(t, "q/1", ""), "q/1")) @@ -356,13 +396,16 @@ func TestProcess_BatchIntentVersionMismatch_Retryable(t *testing.T) { reqStore.EXPECT().UpdateState(gomock.Any(), "q/1", int32(2), int32(3), entity.RequestStateCancelling).Return(nil) batchStore := storagemock.NewMockBatchStore(ctrl) - batchStore.EXPECT().GetByQueueAndStates(gomock.Any(), "q", gomock.Any()).Return([]entity.Batch{batch}, nil) + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + expectNonTerminalMembership(membershipStore, batchStore, "q", []entity.Batch{batch}) + membershipStore.EXPECT().Add(gomock.Any(), "q", entity.BatchStateCancelling, batch.ID).Return(nil) batchStore.EXPECT().UpdateState(gomock.Any(), batch.ID, int32(1), int32(2), entity.BatchStateCancelling). Return(storage.ErrVersionMismatch) store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetRequestStore().Return(reqStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() controller := newController(t, store, registry) err := controller.Process(context.Background(), newDelivery(t, ctrl, cancelPayload(t, "q/1", ""), "q/1")) diff --git a/submitqueue/orchestrator/controller/dlq/BUILD.bazel b/submitqueue/orchestrator/controller/dlq/BUILD.bazel index f38da3af..7857b03d 100644 --- a/submitqueue/orchestrator/controller/dlq/BUILD.bazel +++ b/submitqueue/orchestrator/controller/dlq/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//core/metrics", "//submitqueue/entity", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/dlq/batch_test.go b/submitqueue/orchestrator/controller/dlq/batch_test.go index db55ed6f..1f41bbea 100644 --- a/submitqueue/orchestrator/controller/dlq/batch_test.go +++ b/submitqueue/orchestrator/controller/dlq/batch_test.go @@ -60,6 +60,7 @@ func TestDLQBatchController_Process_FailsAndFansOut(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchFailedTransition(ctrl, store, "q/batch/9", "q", entity.BatchStateMerging) store.EXPECT().GetRequestStore().Return(requestStore).AnyTimes() store.EXPECT().GetRequestLogStore().Return(logStore).AnyTimes() diff --git a/submitqueue/orchestrator/controller/dlq/buildsignal_test.go b/submitqueue/orchestrator/controller/dlq/buildsignal_test.go index 329f35b5..1f85cdb9 100644 --- a/submitqueue/orchestrator/controller/dlq/buildsignal_test.go +++ b/submitqueue/orchestrator/controller/dlq/buildsignal_test.go @@ -67,6 +67,7 @@ func TestDLQBuildSignalController_Process_FansOutToBatch(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchFailedTransition(ctrl, store, "q/batch/2", "q", entity.BatchStateSpeculating) store.EXPECT().GetRequestStore().Return(requestStore).AnyTimes() store.EXPECT().GetRequestLogStore().Return(logStore).AnyTimes() diff --git a/submitqueue/orchestrator/controller/dlq/dlq.go b/submitqueue/orchestrator/controller/dlq/dlq.go index 77570ca4..3795f333 100644 --- a/submitqueue/orchestrator/controller/dlq/dlq.go +++ b/submitqueue/orchestrator/controller/dlq/dlq.go @@ -41,6 +41,7 @@ import ( "github.com/uber/submitqueue/core/consumer" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" "go.uber.org/zap" ) @@ -158,7 +159,7 @@ func failBatch(ctx context.Context, store storage.Storage, logger *zap.SugaredLo ) } else { newVersion := batch.Version + 1 - if err := store.GetBatchStore().UpdateState(ctx, batchID, batch.Version, newVersion, entity.BatchStateFailed); err != nil { + if err := batchstate.UpdateState(ctx, store, batch, newVersion, entity.BatchStateFailed); err != nil { return fmt.Errorf("failed to update batch %s state to failed: %w", batchID, err) } logger.Infow("dlq reconcile: batch marked failed", diff --git a/submitqueue/orchestrator/controller/dlq/dlq_test.go b/submitqueue/orchestrator/controller/dlq/dlq_test.go index 4ace4fe4..de039c74 100644 --- a/submitqueue/orchestrator/controller/dlq/dlq_test.go +++ b/submitqueue/orchestrator/controller/dlq/dlq_test.go @@ -183,6 +183,12 @@ func TestFailRequest_GenericGetErrorIsNonRetryable(t *testing.T) { // failBatch +func expectBatchFailedTransition(ctrl *gomock.Controller, store *storagemock.MockStorage, batchID, queue string, oldState entity.BatchState) { + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + membershipStore.EXPECT().Remove(gomock.Any(), queue, oldState, batchID).Return(nil).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() +} + func TestFailBatch_TransitionsAndFansOut(t *testing.T) { ctrl := gomock.NewController(t) @@ -208,6 +214,7 @@ func TestFailBatch_TransitionsAndFansOut(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchFailedTransition(ctrl, store, "q/batch/1", "q", entity.BatchStateMerging) store.EXPECT().GetRequestStore().Return(requestStore).AnyTimes() store.EXPECT().GetRequestLogStore().Return(logStore).AnyTimes() @@ -239,6 +246,7 @@ func TestFailBatch_AlreadyTerminalFansOutOnly(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchFailedTransition(ctrl, store, "q/batch/1", "q", entity.BatchStateCancelling) store.EXPECT().GetRequestStore().Return(requestStore).AnyTimes() store.EXPECT().GetRequestLogStore().Return(logStore).AnyTimes() @@ -273,6 +281,7 @@ func TestFailBatch_CancellingTransitionsToFailed(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchFailedTransition(ctrl, store, "q/batch/1", "q", entity.BatchStateCancelling) store.EXPECT().GetRequestStore().Return(requestStore).AnyTimes() store.EXPECT().GetRequestLogStore().Return(logStore).AnyTimes() diff --git a/submitqueue/orchestrator/controller/merge/BUILD.bazel b/submitqueue/orchestrator/controller/merge/BUILD.bazel index cd74e742..6ae4d65e 100644 --- a/submitqueue/orchestrator/controller/merge/BUILD.bazel +++ b/submitqueue/orchestrator/controller/merge/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "//submitqueue/entity", "//submitqueue/extension/pusher", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/merge/merge.go b/submitqueue/orchestrator/controller/merge/merge.go index 72607ec6..9726fd4e 100644 --- a/submitqueue/orchestrator/controller/merge/merge.go +++ b/submitqueue/orchestrator/controller/merge/merge.go @@ -29,6 +29,7 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/pusher" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" ) // Controller handles merge queue messages. It loads every request in a batch, @@ -153,7 +154,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r } newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, newState); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, newState); err != nil { coremetrics.NamedCounter(c.metricsScope, "process", "state_update_errors", 1) return fmt.Errorf("failed to transition batch %s to %s: %w", batch.ID, newState, err) } diff --git a/submitqueue/orchestrator/controller/merge/merge_test.go b/submitqueue/orchestrator/controller/merge/merge_test.go index ffdcddd8..fb9eea5a 100644 --- a/submitqueue/orchestrator/controller/merge/merge_test.go +++ b/submitqueue/orchestrator/controller/merge/merge_test.go @@ -76,6 +76,12 @@ func newPusherFactory(ctrl *gomock.Controller, p pusher.Pusher) pusher.Factory { return f } +func expectTerminalBatchTransition(ctrl *gomock.Controller, store *storagemock.MockStorage, batch entity.Batch) { + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + membershipStore.EXPECT().Remove(gomock.Any(), batch.Queue, batch.State, batch.ID).Return(nil).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() +} + func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) store := storagemock.NewMockStorage(ctrl) @@ -117,6 +123,7 @@ func TestController_Process_SuccessfulMerge(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectTerminalBatchTransition(ctrl, store, batch) mockPusher := pushermock.NewMockPusher(ctrl) mockPusher.EXPECT().Push(gomock.Any(), gomock.Any()).DoAndReturn( @@ -172,6 +179,7 @@ func TestController_Process_ForwardsBatchToPusher(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectTerminalBatchTransition(ctrl, store, batch) mockPusher := pushermock.NewMockPusher(ctrl) mockPusher.EXPECT().Push(gomock.Any(), gomock.Any()).DoAndReturn( @@ -216,6 +224,7 @@ func TestController_Process_PushConflictMarksBatchFailed(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectTerminalBatchTransition(ctrl, store, batch) mockPusher := pushermock.NewMockPusher(ctrl) mockPusher.EXPECT().Push(gomock.Any(), gomock.Any()).Return( @@ -256,6 +265,7 @@ func TestController_Process_PushInfraFailureReturnsError(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectTerminalBatchTransition(ctrl, store, batch) mockPusher := pushermock.NewMockPusher(ctrl) mockPusher.EXPECT().Push(gomock.Any(), gomock.Any()).Return( @@ -414,6 +424,7 @@ func TestController_Process_PublishFailureSurfaces(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectTerminalBatchTransition(ctrl, store, batch) mockPusher := pushermock.NewMockPusher(ctrl) mockPusher.EXPECT().Push(gomock.Any(), gomock.Any()).Return( diff --git a/submitqueue/orchestrator/controller/score/BUILD.bazel b/submitqueue/orchestrator/controller/score/BUILD.bazel index 96976e6c..9d2ebb92 100644 --- a/submitqueue/orchestrator/controller/score/BUILD.bazel +++ b/submitqueue/orchestrator/controller/score/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "//submitqueue/entity", "//submitqueue/extension/scorer", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/score/score.go b/submitqueue/orchestrator/controller/score/score.go index 272713bc..bb0c20aa 100644 --- a/submitqueue/orchestrator/controller/score/score.go +++ b/submitqueue/orchestrator/controller/score/score.go @@ -27,6 +27,7 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" "go.uber.org/zap" ) @@ -140,11 +141,12 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r // Atomically update score and state to "scored" in the database newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateScoreAndState(ctx, batch.ID, batch.Version, newVersion, batchScore, entity.BatchStateScored); err != nil { + if err := batchstate.UpdateScoreAndState(ctx, c.store, batch, newVersion, batchScore, entity.BatchStateScored); err != nil { metrics.NamedCounter(c.metricsScope, opName, "storage_errors", 1) return fmt.Errorf("failed to update score for batch %s: %w", batch.ID, err) } batch.Version = newVersion + batch.State = entity.BatchStateScored c.logger.Infow("scored batch", "batch_id", batch.ID, diff --git a/submitqueue/orchestrator/controller/score/score_test.go b/submitqueue/orchestrator/controller/score/score_test.go index f1b91ddd..ac54a441 100644 --- a/submitqueue/orchestrator/controller/score/score_test.go +++ b/submitqueue/orchestrator/controller/score/score_test.go @@ -87,6 +87,13 @@ func mockChangeStore(ctrl *gomock.Controller, requests ...entity.Request) *stora return cs } +func expectScoreMembership(ctrl *gomock.Controller, store *storagemock.MockStorage, batch entity.Batch) { + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + membershipStore.EXPECT().Add(gomock.Any(), batch.Queue, entity.BatchStateScored, batch.ID).Return(nil).AnyTimes() + membershipStore.EXPECT().Remove(gomock.Any(), batch.Queue, batch.State, batch.ID).Return(nil).AnyTimes() + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() +} + // newMockStorage creates a MockStorage with a MockBatchStore, MockRequestStore, and MockChangeStore. func newMockStorage(ctrl *gomock.Controller, batch entity.Batch, request entity.Request) *storagemock.MockStorage { mockBatchStore := storagemock.NewMockBatchStore(ctrl) @@ -98,6 +105,7 @@ func newMockStorage(ctrl *gomock.Controller, batch entity.Batch, request entity. store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + expectScoreMembership(ctrl, store, batch) store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() store.EXPECT().GetChangeStore().Return(mockChangeStore(ctrl, request)).AnyTimes() return store @@ -190,6 +198,7 @@ func TestController_Process_BatchLevelScore(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + expectScoreMembership(ctrl, store, batch) // The controller passes the batch identity to the scorer and persists its score. mockScorer := scorermock.NewMockScorer(ctrl) @@ -246,6 +255,7 @@ func TestController_Process_ScorerFailure(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + expectScoreMembership(ctrl, store, batch) store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() store.EXPECT().GetChangeStore().Return(mockChangeStore(ctrl, request)).AnyTimes() @@ -278,6 +288,7 @@ func TestController_Process_UpdateScoreFailure(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + expectScoreMembership(ctrl, store, batch) store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() store.EXPECT().GetChangeStore().Return(mockChangeStore(ctrl, request)).AnyTimes() diff --git a/submitqueue/orchestrator/controller/speculate/BUILD.bazel b/submitqueue/orchestrator/controller/speculate/BUILD.bazel index fa1e9ae7..d7626f19 100644 --- a/submitqueue/orchestrator/controller/speculate/BUILD.bazel +++ b/submitqueue/orchestrator/controller/speculate/BUILD.bazel @@ -12,6 +12,7 @@ go_library( "//submitqueue/core/topickey", "//submitqueue/entity", "//submitqueue/extension/storage", + "//submitqueue/orchestrator/controller/batchstate", "@com_github_uber_go_tally//:tally", "@org_uber_go_zap//:zap", ], diff --git a/submitqueue/orchestrator/controller/speculate/speculate.go b/submitqueue/orchestrator/controller/speculate/speculate.go index 2631c42e..888b520f 100644 --- a/submitqueue/orchestrator/controller/speculate/speculate.go +++ b/submitqueue/orchestrator/controller/speculate/speculate.go @@ -26,6 +26,7 @@ import ( "github.com/uber/submitqueue/submitqueue/core/topickey" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/storage" + "github.com/uber/submitqueue/submitqueue/orchestrator/controller/batchstate" "go.uber.org/zap" ) @@ -162,7 +163,7 @@ func (c *Controller) startSpeculation(ctx context.Context, batch entity.Batch) e // Optimistic CAS: if the version has already advanced (concurrent speculate), // the next event will see the new state and behave correctly. newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, entity.BatchStateSpeculating); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, entity.BatchStateSpeculating); err != nil { metrics.NamedCounter(c.metricsScope, opName, "storage_errors", 1) return fmt.Errorf("failed to update batch %s state to speculating: %w", batch.ID, err) } @@ -223,7 +224,7 @@ func (c *Controller) tryFinalize(ctx context.Context, batch entity.Batch) error } newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, entity.BatchStateMerging); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, entity.BatchStateMerging); err != nil { metrics.NamedCounter(c.metricsScope, opName, "storage_errors", 1) return fmt.Errorf("failed to update batch %s state to merging: %w", batch.ID, err) } @@ -245,7 +246,7 @@ func (c *Controller) failOnDependency(ctx context.Context, batch entity.Batch, d ) newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, entity.BatchStateFailed); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, entity.BatchStateFailed); err != nil { metrics.NamedCounter(c.metricsScope, opName, "storage_errors", 1) return fmt.Errorf("failed to update batch %s state to failed: %w", batch.ID, err) } @@ -307,7 +308,7 @@ func (c *Controller) cancelBatch(ctx context.Context, batch entity.Batch) error } newVersion := batch.Version + 1 - if err := c.store.GetBatchStore().UpdateState(ctx, batch.ID, batch.Version, newVersion, entity.BatchStateCancelled); err != nil { + if err := batchstate.UpdateState(ctx, c.store, batch, newVersion, entity.BatchStateCancelled); err != nil { metrics.NamedCounter(c.metricsScope, opName, "storage_errors", 1) return fmt.Errorf("failed to update batch %s state to cancelled: %w", batch.ID, err) } diff --git a/submitqueue/orchestrator/controller/speculate/speculate_test.go b/submitqueue/orchestrator/controller/speculate/speculate_test.go index 6b0a6b5f..d8042701 100644 --- a/submitqueue/orchestrator/controller/speculate/speculate_test.go +++ b/submitqueue/orchestrator/controller/speculate/speculate_test.go @@ -90,6 +90,17 @@ func runProcess(t *testing.T, ctrl *gomock.Controller, controller *Controller, b return controller.Process(context.Background(), delivery) } +func expectBatchStateTransition(ctrl *gomock.Controller, store *storagemock.MockStorage, batch entity.Batch, newState entity.BatchState) { + membershipStore := storagemock.NewMockBatchStateMembershipStore(ctrl) + if !newState.IsTerminal() { + membershipStore.EXPECT().Add(gomock.Any(), batch.Queue, newState, batch.ID).Return(nil).AnyTimes() + } + if !batch.State.IsTerminal() && batch.State != newState { + membershipStore.EXPECT().Remove(gomock.Any(), batch.Queue, batch.State, batch.ID).Return(nil).AnyTimes() + } + store.EXPECT().GetBatchStateMembershipStore().Return(membershipStore).AnyTimes() +} + func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) store := storagemock.NewMockStorage(ctrl) @@ -123,6 +134,7 @@ func TestController_Process_StartSpeculation(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateSpeculating) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -141,6 +153,7 @@ func TestController_Process_FinalizeNoDeps(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateMerging) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -161,6 +174,7 @@ func TestController_Process_FinalizeAllDepsSucceeded(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateMerging) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -199,6 +213,7 @@ func TestController_Process_FailedDepFailsBatch(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateFailed) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -221,6 +236,7 @@ func TestController_Process_CancelledDepSkipped(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateMerging) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -375,6 +391,7 @@ func TestController_Process_CancellingTerminalFlow(t *testing.T) { store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() store.EXPECT().GetBatchDependentStore().Return(depStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateCancelled) type pubRec struct { topic string @@ -438,6 +455,7 @@ func TestController_Process_CancellingBuildAlreadyTerminal(t *testing.T) { store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() store.EXPECT().GetBatchDependentStore().Return(depStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateCancelled) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -467,6 +485,7 @@ func TestController_Process_CancellingNoBuildYet(t *testing.T) { store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() store.EXPECT().GetBatchDependentStore().Return(depStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateCancelled) controller := newTestController(t, ctrl, store, nil) require.NoError(t, runProcess(t, ctrl, controller, batch.ID)) @@ -494,6 +513,7 @@ func TestController_Process_CancellingNoDependents(t *testing.T) { store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() store.EXPECT().GetBatchDependentStore().Return(depStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateCancelled) mockPub := queuemock.NewMockPublisher(ctrl) mockPub.EXPECT().Publish(gomock.Any(), "conclude", gomock.Any()).Return(nil).Times(1) @@ -534,6 +554,7 @@ func TestController_Process_CancellingTerminalCASVersionMismatch(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(batchStore).AnyTimes() store.EXPECT().GetBuildStore().Return(buildStore).AnyTimes() + expectBatchStateTransition(ctrl, store, batch, entity.BatchStateCancelled) // BatchDependentStore must NOT be touched — terminal CAS failed before fan-out. // No publish expected (terminal CAS failed before fan-out). diff --git a/test/integration/submitqueue/extension/storage/suite.go b/test/integration/submitqueue/extension/storage/suite.go index 064800c9..94bd7b93 100644 --- a/test/integration/submitqueue/extension/storage/suite.go +++ b/test/integration/submitqueue/extension/storage/suite.go @@ -48,6 +48,12 @@ func (s *StorageContractSuite) SetStorage(store storage.Storage) { s.storage = store } +// GetStorage returns the storage instance under test, for implementation-specific +// suites that need to assert backend-internal behavior alongside the contract tests. +func (s *StorageContractSuite) GetStorage() storage.Storage { + return s.storage +} + // SetLogger sets the logger for tests func (s *StorageContractSuite) SetLogger(log *testutil.TestLogger) { s.log = log @@ -381,3 +387,92 @@ func (s *StorageContractSuite) TestStorage_ChangeCreate_EmptyDetails() { require.Len(t, got, 1) assert.Equal(t, entity.ChangeDetails{}, got[0].Details) } + +// membershipIDs returns sorted batch IDs for stable comparison. +func (s *StorageContractSuite) membershipIDs(queue string, state entity.BatchState) []string { + t := s.T() + ids, err := s.storage.GetBatchStateMembershipStore().ListIDs(s.ctx, queue, state) + require.NoError(t, err) + sort.Strings(ids) + return ids +} + +// TestStorage_BatchStateMembership_AddAndList verifies membership rows are +// listed by their queue/state key. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_AddAndList() { + t := s.T() + ctx := s.ctx + const queue = "bsm-list" + + store := s.storage.GetBatchStateMembershipStore() + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/2")) + + assert.Equal(t, []string{queue + "/batch/1", queue + "/batch/2"}, s.membershipIDs(queue, entity.BatchStateCreated)) +} + +// TestStorage_BatchStateMembership_AddIdempotent verifies repeated Add calls do +// not duplicate rows. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_AddIdempotent() { + t := s.T() + ctx := s.ctx + const queue = "bsm-idempotent" + + store := s.storage.GetBatchStateMembershipStore() + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + + assert.Equal(t, []string{queue + "/batch/1"}, s.membershipIDs(queue, entity.BatchStateCreated)) +} + +// TestStorage_BatchStateMembership_Remove verifies Remove is idempotent and +// deletes only the specified row. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_Remove() { + t := s.T() + ctx := s.ctx + const queue = "bsm-remove" + + store := s.storage.GetBatchStateMembershipStore() + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/2")) + require.NoError(t, store.Remove(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + require.NoError(t, store.Remove(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + + assert.Equal(t, []string{queue + "/batch/2"}, s.membershipIDs(queue, entity.BatchStateCreated)) +} + +// TestStorage_BatchStateMembership_QueueScoped verifies ListIDs never returns +// rows from another queue. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_QueueScoped() { + t := s.T() + ctx := s.ctx + const queueA = "bsm-scoped-a" + const queueB = "bsm-scoped-b" + + store := s.storage.GetBatchStateMembershipStore() + require.NoError(t, store.Add(ctx, queueA, entity.BatchStateCreated, queueA+"/batch/1")) + require.NoError(t, store.Add(ctx, queueB, entity.BatchStateCreated, queueB+"/batch/1")) + require.NoError(t, store.Add(ctx, queueB, entity.BatchStateCreated, queueB+"/batch/2")) + + assert.Equal(t, []string{queueA + "/batch/1"}, s.membershipIDs(queueA, entity.BatchStateCreated)) + assert.Equal(t, []string{queueB + "/batch/1", queueB + "/batch/2"}, s.membershipIDs(queueB, entity.BatchStateCreated)) +} + +// TestStorage_BatchStateMembership_StateScoped verifies ListIDs is scoped by state. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_StateScoped() { + t := s.T() + ctx := s.ctx + const queue = "bsm-state-scoped" + + store := s.storage.GetBatchStateMembershipStore() + require.NoError(t, store.Add(ctx, queue, entity.BatchStateCreated, queue+"/batch/1")) + require.NoError(t, store.Add(ctx, queue, entity.BatchStateSpeculating, queue+"/batch/2")) + + assert.Equal(t, []string{queue + "/batch/1"}, s.membershipIDs(queue, entity.BatchStateCreated)) + assert.Equal(t, []string{queue + "/batch/2"}, s.membershipIDs(queue, entity.BatchStateSpeculating)) +} + +// TestStorage_BatchStateMembership_UnknownKey returns an empty set for a key with no rows. +func (s *StorageContractSuite) TestStorage_BatchStateMembership_UnknownKey() { + assert.Empty(s.T(), s.membershipIDs("bsm-does-not-exist", entity.BatchStateCreated)) +}