diff --git a/aggregation_mode/Cargo.lock b/aggregation_mode/Cargo.lock
index 6e424297a..721b58879 100644
--- a/aggregation_mode/Cargo.lock
+++ b/aggregation_mode/Cargo.lock
@@ -9766,6 +9766,7 @@ dependencies = [
  "base64 0.22.1",
  "bigdecimal",
  "bytes",
+ "chrono",
  "crc",
  "crossbeam-queue",
  "either",
@@ -9843,6 +9844,7 @@ dependencies = [
  "bitflags 2.10.0",
  "byteorder",
  "bytes",
+ "chrono",
  "crc",
  "digest 0.10.7",
  "dotenvy",
@@ -9886,6 +9888,7 @@ dependencies = [
  "bigdecimal",
  "bitflags 2.10.0",
  "byteorder",
+ "chrono",
  "crc",
  "dotenvy",
  "etcetera",
@@ -9922,6 +9925,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea"
 dependencies = [
  "atoi",
+ "chrono",
  "flume",
  "futures-channel",
  "futures-core",
diff --git a/aggregation_mode/db/Cargo.toml b/aggregation_mode/db/Cargo.toml
index 47b3b89e3..08ca38b76 100644
--- a/aggregation_mode/db/Cargo.toml
+++ b/aggregation_mode/db/Cargo.toml
@@ -5,8 +5,7 @@ edition = "2021"
 
 [dependencies]
 tokio = { version = "1"}
-# TODO: enable tls
-sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate" ] }
+sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate", "chrono" ] }
 
 
 [[bin]]
diff --git a/aggregation_mode/db/migrations/002_task_status_updated_at.sql b/aggregation_mode/db/migrations/002_task_status_updated_at.sql
new file mode 100644
index 000000000..f11ae355e
--- /dev/null
+++ b/aggregation_mode/db/migrations/002_task_status_updated_at.sql
@@ -0,0 +1 @@
+ALTER TABLE tasks ADD COLUMN status_updated_at TIMESTAMPTZ DEFAULT now();
diff --git a/aggregation_mode/db/src/types.rs b/aggregation_mode/db/src/types.rs
index 676a780a4..bc5607108 100644
--- a/aggregation_mode/db/src/types.rs
+++ b/aggregation_mode/db/src/types.rs
@@ -1,6 +1,9 @@
 use sqlx::{
     prelude::FromRow,
-    types::{BigDecimal, Uuid},
+    types::{
+        chrono::{DateTime, Utc},
+        BigDecimal, Uuid,
+    },
     Type,
 };
 
@@ -21,6 +24,7 @@ pub struct Task {
     pub program_commitment: Vec<u8>,
     pub merkle_path: Option<Vec<u8>>,
     pub status: TaskStatus,
+    pub status_updated_at: Option<DateTime<Utc>>,
 }
 
 #[derive(Debug, Clone, FromRow)]
diff --git a/aggregation_mode/proof_aggregator/src/backend/db.rs b/aggregation_mode/proof_aggregator/src/backend/db.rs
index feeaa0db2..a4e0da092 100644
--- a/aggregation_mode/proof_aggregator/src/backend/db.rs
+++ b/aggregation_mode/proof_aggregator/src/backend/db.rs
@@ -23,7 +23,18 @@ impl Db {
         Ok(Self { pool })
     }
 
-    pub async fn get_pending_tasks_and_mark_them_as_processing(
+    /// Fetches tasks that are ready to be processed and atomically updates their status.
+    ///
+    /// This function selects up to `limit` tasks for the given `proving_system_id` that are
+    /// either:
+    /// - in `pending` status, or
+    /// - in `processing` status but whose `status_updated_at` timestamp is older than 12 hours
+    ///   (to recover tasks that may have been abandoned or stalled).
+    ///
+    /// The selected rows are locked using `FOR UPDATE SKIP LOCKED` to ensure safe concurrent
+    /// processing by multiple workers. All selected tasks have their status set to
+    /// `processing` and their `status_updated_at` updated to `now()` before being returned.
+    pub async fn get_tasks_to_process_and_update_their_status(
         &self,
         proving_system_id: i32,
         limit: i64,
@@ -32,12 +43,19 @@ impl Db {
             "WITH selected AS (
                 SELECT task_id
                 FROM tasks
-                WHERE proving_system_id = $1 AND status = 'pending'
+                WHERE proving_system_id = $1
+                AND (
+                    status = 'pending'
+                    OR (
+                        status = 'processing'
+                        AND status_updated_at <= now() - interval '12 hours'
+                    )
+                )
                 LIMIT $2
                 FOR UPDATE SKIP LOCKED
             )
             UPDATE tasks t
-            SET status = 'processing'
+            SET status = 'processing', status_updated_at = now()
             FROM selected s
             WHERE t.task_id = s.task_id
             RETURNING t.*;",
@@ -61,7 +79,7 @@ impl Db {
 
         for (task_id, merkle_path) in updates {
             if let Err(e) = sqlx::query(
-                "UPDATE tasks SET merkle_path = $1, status = 'verified', proof = NULL WHERE task_id = $2",
+                "UPDATE tasks SET merkle_path = $1, status = 'verified', status_updated_at = now(), proof = NULL WHERE task_id = $2",
             )
             .bind(merkle_path)
             .bind(task_id)
@@ -83,6 +101,20 @@ impl Db {
         Ok(())
     }
 
-    // TODO: this should be used when rolling back processing proofs on unexpected errors
-    pub async fn mark_tasks_as_pending(&self) {}
+    pub async fn mark_tasks_as_pending(&self, tasks_id: &[Uuid]) -> Result<(), DbError> {
+        if tasks_id.is_empty() {
+            return Ok(());
+        }
+
+        sqlx::query(
+            "UPDATE tasks SET status = 'pending', status_updated_at = now()
+            WHERE task_id = ANY($1) AND status = 'processing'",
+        )
+        .bind(tasks_id)
+        .execute(&self.pool)
+        .await
+        .map_err(|e| DbError::Query(e.to_string()))?;
+
+        Ok(())
+    }
 }
diff --git a/aggregation_mode/proof_aggregator/src/backend/fetcher.rs b/aggregation_mode/proof_aggregator/src/backend/fetcher.rs
index 5ed1df6de..8eedf4a13 100644
--- a/aggregation_mode/proof_aggregator/src/backend/fetcher.rs
+++ b/aggregation_mode/proof_aggregator/src/backend/fetcher.rs
@@ -30,7 +30,7 @@ impl ProofsFetcher {
     ) -> Result<(Vec, Vec<Uuid>), ProofsFetcherError> {
         let tasks = self
             .db
-            .get_pending_tasks_and_mark_them_as_processing(engine.proving_system_id() as i32, limit)
+            .get_tasks_to_process_and_update_their_status(engine.proving_system_id() as i32, limit)
             .await
             .map_err(ProofsFetcherError::Query)?;
 
diff --git a/aggregation_mode/proof_aggregator/src/backend/mod.rs b/aggregation_mode/proof_aggregator/src/backend/mod.rs
index 9dee2164c..2609889f4 100644
--- a/aggregation_mode/proof_aggregator/src/backend/mod.rs
+++ b/aggregation_mode/proof_aggregator/src/backend/mod.rs
@@ -119,7 +119,23 @@ impl ProofAggregator {
 
         info!("Starting proof aggregator service");
         info!("About to aggregate and submit proof to be verified on chain");
-        let res = self.aggregate_and_submit_proofs_on_chain().await;
+
+        let (proofs, tasks_id) = match self
+            .fetcher
+            .fetch_pending_proofs(self.engine.clone(), self.config.total_proofs_limit as i64)
+            .await
+            .map_err(AggregatedProofSubmissionError::FetchingProofs)
+        {
+            Ok(res) => res,
+            Err(e) => {
+                error!("Error while fetching pending proofs: {:?}", e);
+                return;
+            }
+        };
+
+        let res = self
+            .aggregate_and_submit_proofs_on_chain((proofs, &tasks_id))
+            .await;
 
         match res {
             Ok(()) => {
@@ -127,20 +143,18 @@
             }
             Err(err) => {
                 error!("Error while aggregating and submitting proofs: {:?}", err);
+                warn!("Marking tasks back to pending after failure");
+                if let Err(e) = self.db.mark_tasks_as_pending(&tasks_id).await {
+                    error!("Error while marking tasks as pending again: {:?}", e);
+                };
             }
         }
     }
 
-    // TODO: on failure, mark proofs as pending again
     async fn aggregate_and_submit_proofs_on_chain(
         &mut self,
+        (proofs, tasks_id): (Vec, &[Uuid]),
     ) -> Result<(), AggregatedProofSubmissionError> {
-        let (proofs, tasks_id) = self
-            .fetcher
-            .fetch_pending_proofs(self.engine.clone(), self.config.total_proofs_limit as i64)
-            .await
-            .map_err(AggregatedProofSubmissionError::FetchingProofs)?;
-
         if proofs.is_empty() {
             warn!("No proofs collected, skipping aggregation...");
             return Ok(());
@@ -215,7 +229,7 @@ impl ProofAggregator {
         info!("Storing merkle paths for each task...",);
 
         let mut merkle_paths_for_tasks: Vec<(Uuid, Vec<u8>)> = vec![];
-        for (idx, task_id) in tasks_id.into_iter().enumerate() {
+        for (idx, task_id) in tasks_id.iter().enumerate() {
             let Some(proof) = merkle_tree.get_proof_by_pos(idx) else {
                 warn!("Proof not found for task id {task_id}");
                 continue;
@@ -226,7 +240,7 @@ impl ProofAggregator {
                 .flat_map(|e| e.to_vec())
                 .collect::<Vec<u8>>();
 
-            merkle_paths_for_tasks.push((task_id, proof_bytes))
+            merkle_paths_for_tasks.push((*task_id, proof_bytes))
         }
 
         self.db
             .insert_tasks_merkle_path_and_mark_them_as_verified(merkle_paths_for_tasks)
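
Reviewer note on the claim query's concurrency semantics: below is a hand-run, two-session walkthrough of the new query against scratch data. The `proving_system_id = 1` and `LIMIT 2` values are made up for the example, not taken from this diff. It illustrates why `FOR UPDATE SKIP LOCKED` lets concurrent workers claim disjoint batches, and how the 12-hour predicate recovers stalled tasks.

```sql
-- Session A: claim a batch and hold the transaction open.
BEGIN;
WITH selected AS (
    SELECT task_id
    FROM tasks
    WHERE proving_system_id = 1
      AND (
          status = 'pending'
          OR (
              status = 'processing'
              AND status_updated_at <= now() - interval '12 hours'
          )
      )
    LIMIT 2
    FOR UPDATE SKIP LOCKED
)
UPDATE tasks t
SET status = 'processing', status_updated_at = now()
FROM selected s
WHERE t.task_id = s.task_id
RETURNING t.task_id;
-- (transaction still open: the claimed rows remain row-locked)

-- Session B: run the same statement concurrently. SKIP LOCKED makes the
-- SELECT pass over A's locked rows instead of blocking, so B claims the
-- next unclaimed tasks and no task is handed to two workers at once.

-- Recovery path: if a worker crashes after claiming, its tasks stay in
-- 'processing'. Once their status_updated_at falls more than 12 hours
-- behind now(), the OR branch of the WHERE clause makes them claimable again.
```

The explicit `mark_tasks_as_pending` rollback in `run()` covers failures the process survives; the 12-hour window is the backstop for failures it does not (a crash or kill between claim and rollback), at the cost of up to 12 hours of extra latency for those tasks.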