diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 128 |
1 files changed, 84 insertions, 44 deletions
diff --git a/src/main.rs b/src/main.rs index 220c69e..6bb7488 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,9 +6,12 @@ use std::sync::Arc; use anyhow::{anyhow, Error}; use axum::{Json, Router, routing::get}; use axum::extract::{Query, State}; +use futures_util::future::join_all; use futures_util::TryStreamExt; use serde_derive::{Deserialize, Serialize}; use sqlx::{Column, Row, SqlitePool}; +use sqlx::sqlite::SqlitePoolOptions; +use tokio::task::JoinError; use tracing::{debug, info}; use tracing_subscriber::EnvFilter; @@ -33,7 +36,8 @@ impl RelationReader { if cfg.fields.is_empty() { return Err(anyhow!("relation does not have any field")); } - match SqlitePool::connect(&cfg.connect).await { + // TODO make this configurable + match SqlitePoolOptions::new().max_connections(50).connect(&cfg.connect).await { Ok(p) => Ok(RelationReader { cfg: cfg.clone(), db: p, @@ -213,7 +217,7 @@ async fn query( // visit a cloned snapshot, updates will be reflected at once in the next loop round for task in unvisited.clone() { let (field, value) = (&task.field, &task.value); - let mut relations = match state.field_relations.get(field) { + let relations = match state.field_relations.get(field) { Some(v) => (*v).clone(), None => { return Json(QueryResponse { @@ -223,55 +227,91 @@ async fn query( }); } }; - for rel in relations.iter_mut() { - info!("visit: relation {}, field {}, value {}", rel.cfg.name, field, value); - // ensure every (relation, field, value) is visited only once - if !visited.insert(RelationFieldValue { - relation: rel.cfg.name.clone(), - field: field.clone(), - value: value.clone(), - }) { - continue; - } - let result = match rel.query(field, value).await { - Ok(v) => v, + + struct AsyncQueryResult { + rel: RelationReader, + result: HashMap<String, String>, + } + + let field1 = field.clone(); + let field2 = field.clone(); + let value1 = value.clone(); + let value2 = value.clone(); + let tasks: Vec<Result<Result<AsyncQueryResult, (AsyncQueryResult, Error)>, JoinError>> = + join_all(relations.into_iter().filter(|rel| { + info!("visit: relation {}, field {}, value {}", rel.cfg.name, &field1, &value1); + // ensure every (relation, field, value) is visited only once + return visited.insert(RelationFieldValue { + relation: rel.cfg.name.clone(), + field: field1.clone(), + value: value1.clone(), + }); + }).map(|mut rel| { + let field2 = field2.clone(); + let value2 = value2.clone(); + return tokio::spawn(async move { + return match rel.query(&field2, &value2).await { + Ok(v) => Ok(AsyncQueryResult { + rel: rel.clone(), + result: v, + }), + Err(why) => Err((AsyncQueryResult { + rel: rel.clone(), + result: Default::default(), + }, why)), + }; + }); + }).collect::<Vec<_>>()).await; + for t in tasks { + match t { + Ok(v) => match v { + Ok(v) => { + for (field, value) in v.result { + if let Some(set) = all_result.get_mut(&field) { + set.insert(value.clone()); + } else { + let mut s = HashSet::new(); + s.insert(value.clone()); + all_result.insert(field.clone(), s); + } + let v = RelationFieldValue { + relation: v.rel.cfg.name.clone(), + field: field.clone(), + value, + }; + // skip non-distinct fields to prevent generating irrelevant results + if !state.fields.get(&field).expect("missing field info").distinct { + continue; + } + if visited.contains(&v) { + // don't re-add visited values + continue; + } + unvisited.insert(v); + } + unvisited.remove(&RelationFieldValue { + relation: v.rel.cfg.name.clone(), + field: field.clone(), + value: value.clone(), + }); + } + Err((r, why)) => { + return Json(QueryResponse { + success: false, + message: format!("failed to query relation `{r}` with field `{f}`, value `{v}`: {why}", + r = &r.rel.cfg.name, f = field, v = value), + data: Default::default(), + }); + } + }, Err(why) => { return Json(QueryResponse { success: false, - message: format!("failed to query relation `{r}` with field `{f}`, value `{v}`: {why}", - r = &rel.cfg.name, f = field, v = value), + message: format!("failed to join query task: {}", why), data: Default::default(), }); } }; - for (field, value) in result { - if let Some(set) = all_result.get_mut(&field) { - set.insert(value.clone()); - } else { - let mut s = HashSet::new(); - s.insert(value.clone()); - all_result.insert(field.clone(), s); - } - let v = RelationFieldValue { - relation: rel.cfg.name.clone(), - field: field.clone(), - value, - }; - // skip non-distinct fields to prevent generating irrelevant results - if !state.fields.get(&field).expect("missing field info").distinct { - continue; - } - if visited.contains(&v) { - // don't re-add visited values - continue; - } - unvisited.insert(v); - } - unvisited.remove(&RelationFieldValue { - relation: rel.cfg.name.clone(), - field: field.clone(), - value: value.clone(), - }); } } depth += 1; |