From 9f444382bbb87a09e027b5b7002e948e52ae6c4c Mon Sep 17 00:00:00 2001 From: Akhilesh Sharma Date: Sun, 19 Jan 2025 18:25:43 -0800 Subject: [PATCH 1/7] added feature flag for azure --- compose.yaml | 4 +-- core/Cargo.toml | 3 ++ core/migrations/.gitkeep | 2 ++ core/src/lib.rs | 1 - core/src/models/chunkr/task.rs | 6 +++- core/src/models/chunkr/upload.rs | 42 +++++++++++++++++++++++ core/src/pipeline/segmentation_and_ocr.rs | 10 ++++-- 7 files changed, 62 insertions(+), 6 deletions(-) diff --git a/compose.yaml b/compose.yaml index 1bd68cac1..ee461c238 100644 --- a/compose.yaml +++ b/compose.yaml @@ -108,7 +108,7 @@ services: env_file: - .env deploy: - replicas: 1 + replicas: 0 restart: always task: @@ -121,7 +121,7 @@ services: env_file: - .env deploy: - replicas: 1 + replicas: 0 restart: always web: diff --git a/core/Cargo.toml b/core/Cargo.toml index f034b2d7f..fa33fd923 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -4,6 +4,9 @@ version = "0.0.0" edition = "2021" default-run = "core" +[features] +azure = [] + [dependencies] actix-cors = "0.7.0" actix-multipart = "0.7.2" diff --git a/core/migrations/.gitkeep b/core/migrations/.gitkeep index e69de29bb..5e2d6d3bb 100644 --- a/core/migrations/.gitkeep +++ b/core/migrations/.gitkeep @@ -0,0 +1,2 @@ +ALTER TABLE TASKS + ADD COLUMN version TEXT; diff --git a/core/src/lib.rs b/core/src/lib.rs index 2c7b9a1cd..279805c0b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -25,7 +25,6 @@ pub mod pipeline; pub mod routes; pub mod utils; -use configs::worker_config; use jobs::init::init_jobs; use middleware::auth::AuthMiddlewareFactory; use routes::github::get_github_repo_info; diff --git a/core/src/models/chunkr/task.rs b/core/src/models/chunkr/task.rs index 5e1227d9f..e7091b857 100644 --- a/core/src/models/chunkr/task.rs +++ b/core/src/models/chunkr/task.rs @@ -3,6 +3,8 @@ use crate::models::chunkr::chunk_processing::ChunkProcessing; use crate::models::chunkr::output::{OutputResponse, Segment, SegmentType}; use crate::models::chunkr::segment_processing::SegmentProcessing; use crate::models::chunkr::structured_extraction::JsonSchema; +#[cfg(feature = "azure")] +use crate::models::chunkr::upload::Pipeline; use crate::models::chunkr::upload::{OcrStrategy, SegmentationStrategy}; use crate::utils::clients::get_pg_client; use crate::utils::services::file_operations::check_file_type; @@ -585,7 +587,6 @@ pub struct Configuration { /// Whether to use high-resolution images for cropping and post-processing. pub high_resolution: bool, pub json_schema: Option, - #[serde(skip_serializing_if = "Option::is_none")] #[deprecated] pub model: Option, pub ocr_strategy: OcrStrategy, @@ -595,6 +596,9 @@ pub struct Configuration { #[deprecated] /// The target number of words in each chunk. If 0, each chunk will contain a single segment. pub target_chunk_length: Option, + #[cfg(feature = "azure")] + #[serde(skip_serializing_if = "Option::is_none")] + pub pipeline: Option, } // TODO: Move to output diff --git a/core/src/models/chunkr/upload.rs b/core/src/models/chunkr/upload.rs index 01be4e91f..3dedd6a25 100644 --- a/core/src/models/chunkr/upload.rs +++ b/core/src/models/chunkr/upload.rs @@ -12,6 +12,25 @@ use serde::{Deserialize, Serialize}; use strum_macros::{Display, EnumString}; use utoipa::{IntoParams, ToSchema}; +#[cfg_attr( + feature = "azure", + derive( + Debug, + Serialize, + Deserialize, + PartialEq, + Clone, + ToSql, + FromSql, + ToSchema, + Display, + EnumString, + ) +)] +pub enum Pipeline { + Azure, +} + #[derive(Debug, MultipartForm, ToSchema, IntoParams)] #[into_params(parameter_in = Query)] pub struct CreateForm { @@ -51,6 +70,13 @@ pub struct CreateForm { /// The target chunk length to be used for chunking. /// If 0, each chunk will contain a single segment. pub target_chunk_length: Option>, + #[cfg(feature = "azure")] + #[param(style = Form, value_type = Option)] + #[schema(value_type = Option)] + /// The pipeline to use for processing. + /// If pipeline is set to Azure then Azure layout analysis will be used for segmentation and OCR. + /// The output will be unified to the Chunkr `output` format. + pub pipeline: Option>, } impl CreateForm { @@ -149,6 +175,11 @@ impl CreateForm { .unwrap_or(SegmentationStrategy::LayoutAnalysis) } + #[cfg(feature = "azure")] + fn get_pipeline(&self) -> Option { + self.pipeline.as_ref().map(|e| e.0.clone()) + } + pub fn to_configuration(&self) -> Configuration { Configuration { chunk_processing: self.get_chunk_processing(), @@ -160,6 +191,8 @@ impl CreateForm { segment_processing: self.get_segment_processing(), segmentation_strategy: self.get_segmentation_strategy(), target_chunk_length: None, + #[cfg(feature = "azure")] + pipeline: self.get_pipeline(), } } } @@ -191,6 +224,13 @@ pub struct UpdateForm { #[param(style = Form, value_type = Option)] #[schema(value_type = Option)] pub segmentation_strategy: Option>, + #[cfg(feature = "azure")] + #[param(style = Form, value_type = Option)] + #[schema(value_type = Option)] + /// The pipeline to use for processing. + /// If pipeline is set to Azure then Azure layout analysis will be used for segmentation and OCR. + /// The output will be unified to the Chunkr output. + pub pipeline: Option>, } impl UpdateForm { @@ -276,6 +316,8 @@ impl UpdateForm { .map(|e| e.0.clone()) .unwrap_or(current_config.segmentation_strategy.clone()), target_chunk_length: None, + #[cfg(feature = "azure")] + pipeline: None, } } } diff --git a/core/src/pipeline/segmentation_and_ocr.rs b/core/src/pipeline/segmentation_and_ocr.rs index beb04c829..c0c89c3e8 100644 --- a/core/src/pipeline/segmentation_and_ocr.rs +++ b/core/src/pipeline/segmentation_and_ocr.rs @@ -86,8 +86,14 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box = pipeline.page_images.as_ref().unwrap().iter().map(|x| x.as_ref()).collect(); - + let pages: Vec<_> = pipeline + .page_images + .as_ref() + .unwrap() + .iter() + .map(|x| x.as_ref()) + .collect(); + let page_segments = match process_pages_batch( &pages, pipeline.get_task()?.configuration.clone(), From 8706ab454aa43466b14f8d9147bb538408c5ef73 Mon Sep 17 00:00:00 2001 From: Akhilesh Sharma Date: Sun, 19 Jan 2025 18:57:29 -0800 Subject: [PATCH 2/7] updated task worker to accept azure as pipeline type --- core/src/models/chunkr/task.rs | 21 ++++++++++++ core/src/models/chunkr/upload.rs | 21 ++---------- core/src/pipeline/azure.rs | 39 +++++++++++++++++++++ core/src/pipeline/chunking.rs | 41 +++++++++++++++++++++++ core/src/pipeline/mod.rs | 3 ++ core/src/pipeline/segmentation_and_ocr.rs | 23 ++++--------- core/src/utils/services/chunking.rs | 22 ++++++------ core/src/workers/task.rs | 40 ++++++++++++++++------ 8 files changed, 154 insertions(+), 56 deletions(-) create mode 100644 core/src/pipeline/azure.rs create mode 100644 core/src/pipeline/chunking.rs diff --git a/core/src/models/chunkr/task.rs b/core/src/models/chunkr/task.rs index e7091b857..be795e7bd 100644 --- a/core/src/models/chunkr/task.rs +++ b/core/src/models/chunkr/task.rs @@ -554,6 +554,8 @@ pub struct TaskResponse { pub task_url: Option, } +//TODO: Move to configuration + #[derive( Debug, Clone, @@ -576,6 +578,25 @@ pub enum Status { Cancelled, } +#[cfg_attr( + feature = "azure", + derive( + Debug, + Serialize, + Deserialize, + PartialEq, + Clone, + ToSql, + FromSql, + ToSchema, + Display, + EnumString, + ) +)] +pub enum Pipeline { + Azure, +} + #[derive(Debug, Serialize, Deserialize, Clone, ToSql, FromSql, ToSchema)] /// The configuration used for the task. pub struct Configuration { diff --git a/core/src/models/chunkr/upload.rs b/core/src/models/chunkr/upload.rs index 3dedd6a25..50cbee857 100644 --- a/core/src/models/chunkr/upload.rs +++ b/core/src/models/chunkr/upload.rs @@ -5,6 +5,8 @@ use crate::models::chunkr::chunk_processing::{ use crate::models::chunkr::segment_processing::SegmentProcessing; use crate::models::chunkr::structured_extraction::JsonSchema; use crate::models::chunkr::task::Configuration; +#[cfg(feature = "azure")] +use crate::models::chunkr::task::Pipeline; use actix_multipart::form::json::Json as MPJson; use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; use postgres_types::{FromSql, ToSql}; @@ -12,25 +14,6 @@ use serde::{Deserialize, Serialize}; use strum_macros::{Display, EnumString}; use utoipa::{IntoParams, ToSchema}; -#[cfg_attr( - feature = "azure", - derive( - Debug, - Serialize, - Deserialize, - PartialEq, - Clone, - ToSql, - FromSql, - ToSchema, - Display, - EnumString, - ) -)] -pub enum Pipeline { - Azure, -} - #[derive(Debug, MultipartForm, ToSchema, IntoParams)] #[into_params(parameter_in = Query)] pub struct CreateForm { diff --git a/core/src/pipeline/azure.rs b/core/src/pipeline/azure.rs new file mode 100644 index 000000000..b78b1bad7 --- /dev/null +++ b/core/src/pipeline/azure.rs @@ -0,0 +1,39 @@ +use crate::models::chunkr::output::Segment; +use crate::models::chunkr::pipeline::Pipeline; +use crate::models::chunkr::task::Status; +use crate::utils::services::chunking; + +/// Use Azure document layout analysis to perform segmentation and ocr +pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box> { + pipeline + .get_task()? + .update( + Some(Status::Processing), + Some("Chunking".to_string()), + None, + None, + None, + None, + ) + .await?; + + let segments: Vec = pipeline + .output + .chunks + .clone() + .into_iter() + .map(|c| c.segments) + .flatten() + .collect(); + + let chunk_processing = pipeline.get_task()?.configuration.chunk_processing.clone(); + + let chunks = chunking::hierarchical_chunking( + segments, + chunk_processing.target_length, + chunk_processing.ignore_headers_and_footers, + )?; + + pipeline.output.chunks = chunks; + Ok(()) +} diff --git a/core/src/pipeline/chunking.rs b/core/src/pipeline/chunking.rs new file mode 100644 index 000000000..70da8dd3a --- /dev/null +++ b/core/src/pipeline/chunking.rs @@ -0,0 +1,41 @@ +use crate::models::chunkr::output::Segment; +use crate::models::chunkr::pipeline::Pipeline; +use crate::models::chunkr::task::Status; +use crate::utils::services::chunking; + +/// Chunk the segments +/// +/// This function will perform chunking on the segments +pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box> { + pipeline + .get_task()? + .update( + Some(Status::Processing), + Some("Chunking".to_string()), + None, + None, + None, + None, + ) + .await?; + + let segments: Vec = pipeline + .output + .chunks + .clone() + .into_iter() + .map(|c| c.segments) + .flatten() + .collect(); + + let chunk_processing = pipeline.get_task()?.configuration.chunk_processing.clone(); + + let chunks = chunking::hierarchical_chunking( + segments, + chunk_processing.target_length, + chunk_processing.ignore_headers_and_footers, + )?; + + pipeline.output.chunks = chunks; + Ok(()) +} diff --git a/core/src/pipeline/mod.rs b/core/src/pipeline/mod.rs index 90963b547..8dcf73d92 100644 --- a/core/src/pipeline/mod.rs +++ b/core/src/pipeline/mod.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "azure")] +pub mod azure; +pub mod chunking; pub mod convert_to_images; pub mod crop; pub mod segment_processing; diff --git a/core/src/pipeline/segmentation_and_ocr.rs b/core/src/pipeline/segmentation_and_ocr.rs index c0c89c3e8..21f4dcad2 100644 --- a/core/src/pipeline/segmentation_and_ocr.rs +++ b/core/src/pipeline/segmentation_and_ocr.rs @@ -1,8 +1,7 @@ -use crate::models::chunkr::output::{BoundingBox, OCRResult, Segment, SegmentType}; +use crate::models::chunkr::output::{BoundingBox, Chunk, OCRResult, Segment, SegmentType}; use crate::models::chunkr::pipeline::Pipeline; use crate::models::chunkr::task::{Configuration, Status}; use crate::models::chunkr::upload::{OcrStrategy, SegmentationStrategy}; -use crate::utils::services::chunking; use crate::utils::services::images; use crate::utils::services::ocr; use crate::utils::services::pdf; @@ -120,21 +119,11 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box Result, Box> { let mut chunks: Vec = Vec::new(); - if target_length == 0 || target_length == 1 { - for segment in segments { - chunks.push(Chunk::new(vec![segment.clone()])); - } - return Ok(chunks); - } - let mut current_segments: Vec = Vec::new(); let mut current_word_count = 0; @@ -364,17 +357,26 @@ mod tests { // First chunk: Title 1 + Section 1 + text assert_eq!(chunks[0].segments.len(), 3); assert_eq!(chunks[0].segments[0].segment_type, SegmentType::Title); - assert_eq!(chunks[0].segments[1].segment_type, SegmentType::SectionHeader); + assert_eq!( + chunks[0].segments[1].segment_type, + SegmentType::SectionHeader + ); assert_eq!(chunks[0].segments[2].segment_type, SegmentType::Text); // Second chunk: Section 2 + text assert_eq!(chunks[1].segments.len(), 2); - assert_eq!(chunks[1].segments[0].segment_type, SegmentType::SectionHeader); + assert_eq!( + chunks[1].segments[0].segment_type, + SegmentType::SectionHeader + ); assert_eq!(chunks[1].segments[1].segment_type, SegmentType::Text); // Third chunk: Title 2 + Section 3 assert_eq!(chunks[2].segments.len(), 2); assert_eq!(chunks[2].segments[0].segment_type, SegmentType::Title); - assert_eq!(chunks[2].segments[1].segment_type, SegmentType::SectionHeader); + assert_eq!( + chunks[2].segments[1].segment_type, + SegmentType::SectionHeader + ); } } diff --git a/core/src/workers/task.rs b/core/src/workers/task.rs index a20bf4cad..7aa4b84f6 100644 --- a/core/src/workers/task.rs +++ b/core/src/workers/task.rs @@ -4,6 +4,7 @@ use core::models::chunkr::pipeline::Pipeline; use core::models::chunkr::task::Status; use core::models::chunkr::task::TaskPayload; use core::models::rrq::queue::QueuePayload; +use core::pipeline::chunking; use core::pipeline::convert_to_images; use core::pipeline::crop; use core::pipeline::segment_processing; @@ -21,6 +22,9 @@ async fn execute_step( println!("Executing step: {}", step); let start = std::time::Instant::now(); match step { + #[cfg(feature = "azure")] + "azure" => azure::process(pipeline).await, + "chunking" => chunking::process(pipeline).await, "convert_to_images" => convert_to_images::process(pipeline).await, "crop" => crop::process(pipeline).await, "segmentation_and_ocr" => segmentation_and_ocr::process(pipeline).await, @@ -42,15 +46,31 @@ async fn execute_step( /// Orchestrate the task /// /// This function defines the order of the steps in the pipeline. -fn orchestrate_task() -> Vec<&'static str> { - vec![ - "update_metadata", - "convert_to_images", - "segmentation_and_ocr", - "crop", - "segment_processing", - "structured_extraction", - ] +fn orchestrate_task( + pipeline: &mut Pipeline, +) -> Result, Box> { + let mut steps = vec!["update_metadata", "convert_to_images"]; + #[cfg(feature = "azure")] + { + match pipeline.get_task()?.configuration.pipeline.clone() { + core::models::task::pipeline::Pipeline::Azure => steps.push("azure"), + _ => steps.push("segmentation_and_ocr"), + } + } + #[cfg(not(feature = "azure"))] + { + steps.push("segmentation_and_ocr"); + } + let chunk_processing = pipeline.get_task()?.configuration.chunk_processing.clone(); + if chunk_processing.target_length == 0 || chunk_processing.target_length == 1 { + steps.push("chunking"); + } + steps.push("segment_processing"); + let json_schema = pipeline.get_task()?.configuration.json_schema.clone(); + if json_schema.is_some() { + steps.push("structured_extraction"); + } + Ok(steps) } pub async fn process(payload: QueuePayload) -> Result<(), Box> { @@ -67,7 +87,7 @@ pub async fn process(payload: QueuePayload) -> Result<(), Box Date: Mon, 20 Jan 2025 14:34:46 -0800 Subject: [PATCH 3/7] azure layout analysis works --- core/.gitignore | 4 +- core/src/configs/azure_config.rs | 36 ++ core/src/configs/mod.rs | 3 + core/src/models/chunkr/azure.rs | 577 +++++++++++++++++++++ core/src/models/chunkr/mod.rs | 3 + core/src/models/chunkr/pipeline.rs | 5 + core/src/models/chunkr/task.rs | 86 +-- core/src/models/chunkr/upload.rs | 20 +- core/src/pipeline/azure.rs | 24 +- core/src/pipeline/chunking.rs | 1 + core/src/pipeline/convert_to_images.rs | 1 + core/src/pipeline/crop.rs | 1 + core/src/pipeline/mod.rs | 5 +- core/src/pipeline/segment_processing.rs | 9 +- core/src/pipeline/segmentation_and_ocr.rs | 13 +- core/src/pipeline/structured_extraction.rs | 7 +- core/src/pipeline/update_metadata.rs | 58 +-- core/src/utils/routes/update_task.rs | 1 + core/src/utils/services/azure.rs | 102 ++++ core/src/utils/services/mod.rs | 3 + core/src/workers/task.rs | 8 +- 21 files changed, 843 insertions(+), 124 deletions(-) create mode 100644 core/src/configs/azure_config.rs create mode 100644 core/src/models/chunkr/azure.rs create mode 100644 core/src/utils/services/azure.rs diff --git a/core/.gitignore b/core/.gitignore index 0c0bf4ac4..1bf6dd0d0 100644 --- a/core/.gitignore +++ b/core/.gitignore @@ -40,4 +40,6 @@ cmd.txt pdfium-binaries -.cargo \ No newline at end of file +.cargo + +azure-analysis-response.json \ No newline at end of file diff --git a/core/src/configs/azure_config.rs b/core/src/configs/azure_config.rs new file mode 100644 index 000000000..c7ff6bd31 --- /dev/null +++ b/core/src/configs/azure_config.rs @@ -0,0 +1,36 @@ +use config::{Config as ConfigTrait, ConfigError}; +use dotenvy::dotenv_override; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Config { + #[serde(default = "default_api_version")] + pub api_version: String, + pub endpoint: String, + pub key: String, + #[serde(default = "default_model_id")] + pub model_id: String, +} + +fn default_api_version() -> String { + "2024-11-30".to_string() +} + +fn default_model_id() -> String { + "prebuilt-layout".to_string() +} + +impl Config { + pub fn from_env() -> Result { + dotenv_override().ok(); + + ConfigTrait::builder() + .add_source( + config::Environment::default() + .prefix("AZURE") + .separator("__"), + ) + .build()? + .try_deserialize() + } +} diff --git a/core/src/configs/mod.rs b/core/src/configs/mod.rs index 211e15e92..f1aa8d7a6 100644 --- a/core/src/configs/mod.rs +++ b/core/src/configs/mod.rs @@ -12,3 +12,6 @@ pub mod stripe_config; pub mod throttle_config; pub mod user_config; pub mod worker_config; + +#[cfg(feature = "azure")] +pub mod azure_config; diff --git a/core/src/models/chunkr/azure.rs b/core/src/models/chunkr/azure.rs new file mode 100644 index 000000000..7d0d8328d --- /dev/null +++ b/core/src/models/chunkr/azure.rs @@ -0,0 +1,577 @@ +use crate::models::chunkr::output::{BoundingBox, Chunk, OCRResult, Segment, SegmentType}; +use crate::utils::services::html::convert_table_to_markdown; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::error::Error; + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AzureAnalysisResponse { + pub status: String, + pub created_date_time: Option, + pub last_updated_date_time: Option, + pub analyze_result: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AnalyzeResult { + pub api_version: Option, + pub model_id: Option, + pub string_index_type: Option, + pub content: Option, + pub pages: Option>, + pub tables: Option>, + pub paragraphs: Option>, + pub styles: Option>, + pub content_format: Option, + pub sections: Option>, + pub figures: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Span { + pub offset: Option, + pub length: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Page { + pub page_number: Option, + pub angle: Option, + pub width: Option, + pub height: Option, + pub unit: Option, + pub words: Option>, + pub selection_marks: Option>, + pub lines: Option>, + pub spans: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Word { + pub content: Option, + pub polygon: Option>, + pub confidence: Option, + pub span: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SelectionMark { + pub state: Option, + pub polygon: Option>, + pub confidence: Option, + pub span: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Line { + pub content: Option, + pub polygon: Option>, + pub spans: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Table { + pub row_count: Option, + pub column_count: Option, + pub cells: Option>, + pub bounding_regions: Option>, + pub spans: Option>, + pub caption: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Cell { + pub kind: Option, + pub row_index: Option, + pub column_index: Option, + pub content: Option, + pub bounding_regions: Option>, + pub spans: Option>, + #[serde(default)] + pub elements: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BoundingRegion { + pub page_number: Option, + pub polygon: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Caption { + pub content: Option, + pub bounding_regions: Option>, + pub spans: Option>, + pub elements: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Paragraph { + pub spans: Option>, + pub bounding_regions: Option>, + pub role: Option, + pub content: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Section { + pub spans: Option>, + pub elements: Option>, +} + +#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Figure { + pub id: String, + pub bounding_regions: Vec, + pub spans: Option>, + pub elements: Option>, + pub caption: Option, +} + +impl AzureAnalysisResponse { + pub fn to_chunks(&self) -> Result, Box> { + let mut all_segments: Vec = Vec::new(); + + if let Some(analyze_result) = &self.analyze_result { + if let Some(paragraphs) = &analyze_result.paragraphs { + let mut replacements: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + let mut skip_paragraphs = std::collections::HashSet::new(); + + let page_dimensions = if let Some(pages) = &analyze_result.pages { + pages + .iter() + .map(|page| { + let page_number = page.page_number.unwrap_or(1) as u32; + let (width, height) = match page.unit.as_deref() { + Some("inch") => ( + inches_to_pixels(page.width.unwrap_or(0.0) as f64), + inches_to_pixels(page.height.unwrap_or(0.0) as f64), + ), + _ => (0.0, 0.0), + }; + (page_number, (width, height)) + }) + .collect::>() + } else { + std::collections::HashMap::new() + }; + + if let Some(tables) = &analyze_result.tables { + for table in tables { + let mut min_paragraph_idx = usize::MAX; + + if let Some(cells) = &table.cells { + for cell in cells { + if let Some(elements) = &cell.elements { + for element in elements { + if let Some(idx) = extract_paragraph_index(element) { + min_paragraph_idx = min_paragraph_idx.min(idx); + skip_paragraphs.insert(idx); + } + } + } + } + } + + if let Some(regions) = &table.bounding_regions { + if let Some(first_region) = regions.first() { + let page_number = first_region.page_number.unwrap_or(1) as u32; + let (page_width, page_height) = page_dimensions + .get(&page_number) + .copied() + .unwrap_or((0.0, 0.0)); + + let bbox = create_bounding_box(first_region); + let segment = Segment { + bbox, + confidence: None, + content: table_to_text(table), + html: Some(table_to_html(table)), + markdown: Some(table_to_markdown(table)), + image: None, + llm: None, + ocr: Vec::new(), + page_height, + page_width, + page_number, + segment_id: uuid::Uuid::new_v4().to_string(), + segment_type: SegmentType::Table, + }; + + if min_paragraph_idx != usize::MAX { + replacements + .entry(min_paragraph_idx) + .or_insert_with(Vec::new) + .push(segment); + } + } + } + + if let Some(caption) = &table.caption { + process_caption( + caption, + &mut replacements, + &mut skip_paragraphs, + &page_dimensions, + ); + } + } + } + + if let Some(figures) = &analyze_result.figures { + for figure in figures { + let mut min_paragraph_idx = usize::MAX; + + if let Some(elements) = &figure.elements { + for element in elements { + if let Some(idx) = extract_paragraph_index(element) { + min_paragraph_idx = min_paragraph_idx.min(idx); + skip_paragraphs.insert(idx); + } + } + } + + if !figure.bounding_regions.is_empty() { + let first_region = &figure.bounding_regions[0]; + let page_number = first_region.page_number.unwrap_or(1) as u32; + let (page_width, page_height) = page_dimensions + .get(&page_number) + .copied() + .unwrap_or((0.0, 0.0)); + + let bbox = create_bounding_box(first_region); + let segment = Segment { + bbox, + confidence: None, + content: String::new(), + html: None, + markdown: None, + image: None, + llm: None, + ocr: Vec::new(), + page_height, + page_width, + page_number, + segment_id: uuid::Uuid::new_v4().to_string(), + segment_type: SegmentType::Picture, + }; + + if min_paragraph_idx != usize::MAX { + replacements + .entry(min_paragraph_idx) + .or_insert_with(Vec::new) + .push(segment); + } + } + + if let Some(caption) = &figure.caption { + process_caption( + caption, + &mut replacements, + &mut skip_paragraphs, + &page_dimensions, + ); + } + } + } + + for (idx, paragraph) in paragraphs.iter().enumerate() { + if skip_paragraphs.contains(&idx) { + if let Some(replacement_segments) = replacements.get(&idx) { + all_segments.extend(replacement_segments.clone()); + } + continue; + } + + if let Some(regions) = ¶graph.bounding_regions { + if let Some(first_region) = regions.first() { + let page_number = first_region.page_number.unwrap_or(1) as u32; + let (page_width, page_height) = page_dimensions + .get(&page_number) + .copied() + .unwrap_or((0.0, 0.0)); + + let bbox = create_bounding_box(first_region); + let segment_type = match paragraph.role.as_deref() { + Some("title") => SegmentType::Title, + Some("sectionHeading") => SegmentType::SectionHeader, + Some("pageHeader") => SegmentType::PageHeader, + Some("pageNumber") => SegmentType::PageFooter, + Some("pageFooter") => SegmentType::PageFooter, + _ => SegmentType::Text, + }; + + let segment = Segment { + bbox, + confidence: None, + content: paragraph + .content + .clone() + .unwrap_or_default() + .replace(":selected:", ""), + html: None, + markdown: None, + image: None, + llm: None, + ocr: Vec::new(), + page_height, + page_width, + page_number, + segment_id: uuid::Uuid::new_v4().to_string(), + segment_type, + }; + all_segments.push(segment); + } + } + } + + // Assign OCR words to segments based on intersection area + if let Some(pages) = &analyze_result.pages { + for page in pages { + let page_number = page.page_number.unwrap_or(1) as u32; + + if let Some(words) = &page.words { + for word in words { + if let (Some(polygon), Some(content), Some(confidence)) = + (&word.polygon, &word.content, &word.confidence) + { + let word_bbox = create_word_bbox(polygon)?; + let mut max_area = 0.0; + let mut best_segment_idx = None; + + for (idx, segment) in all_segments.iter().enumerate() { + if segment.page_number == page_number { + let area = segment.bbox.intersection_area(&word_bbox); + if area > max_area { + max_area = area; + best_segment_idx = Some(idx); + } + } + } + + if let Some(idx) = best_segment_idx { + let segment = &all_segments[idx]; + let relative_bbox = BoundingBox::new( + word_bbox.left - segment.bbox.left, + word_bbox.top - segment.bbox.top, + word_bbox.width, + word_bbox.height, + ); + + all_segments[idx].ocr.push(OCRResult { + text: content.clone().replace(":selected:", ""), + confidence: Some(*confidence as f32), + bbox: relative_bbox, + }); + } + } + } + } + } + } + } + } + + Ok(all_segments + .into_iter() + .map(|segment| Chunk::new(vec![segment])) + .collect()) + } +} + +fn process_caption( + caption: &Caption, + replacements: &mut std::collections::BTreeMap>, + skip_paragraphs: &mut std::collections::HashSet, + page_dimensions: &std::collections::HashMap, +) { + if let Some(elements) = &caption.elements { + if let Some(first_idx) = elements.first().and_then(|e| extract_paragraph_index(e)) { + for element in elements { + if let Some(idx) = extract_paragraph_index(element) { + skip_paragraphs.insert(idx); + } + } + + if let Some(regions) = &caption.bounding_regions { + if let Some(first_region) = regions.first() { + let page_number = first_region.page_number.unwrap_or(1) as u32; + let (page_width, page_height) = page_dimensions + .get(&page_number) + .copied() + .unwrap_or((0.0, 0.0)); + + let bbox = create_bounding_box(first_region); + let segment = Segment { + bbox, + confidence: None, + content: caption + .content + .clone() + .unwrap_or_default() + .replace(":selected:", ""), + html: None, + markdown: None, + image: None, + llm: None, + ocr: Vec::new(), + page_height, + page_width, + page_number, + segment_id: uuid::Uuid::new_v4().to_string(), + segment_type: SegmentType::Caption, + }; + replacements + .entry(first_idx) + .or_insert_with(Vec::new) + .push(segment); + } + } + } + } +} + +fn extract_paragraph_index(element: &str) -> Option { + element.strip_prefix("/paragraphs/")?.parse::().ok() +} + +fn inches_to_pixels(inches: f64) -> f32 { + (inches * 72.0) as f32 +} + +fn create_bounding_box(region: &BoundingRegion) -> BoundingBox { + if let Some(polygon) = ®ion.polygon { + if polygon.len() >= 8 { + let points: Vec = polygon + .iter() + .map(|&coord| inches_to_pixels(coord)) + .collect(); + + let left = points + .iter() + .step_by(2) + .fold(f32::INFINITY, |acc, &x| acc.min(x)); + let top = points + .iter() + .skip(1) + .step_by(2) + .fold(f32::INFINITY, |acc, &y| acc.min(y)); + let right = points + .iter() + .step_by(2) + .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x)); + let bottom = points + .iter() + .skip(1) + .step_by(2) + .fold(f32::NEG_INFINITY, |acc, &y| acc.max(y)); + + return BoundingBox::new(left, top, right - left, bottom - top); + } + } + BoundingBox::new(0.0, 0.0, 0.0, 0.0) +} + +fn table_to_text(table: &Table) -> String { + table + .cells + .as_ref() + .map(|cells| { + cells + .iter() + .filter_map(|cell| cell.content.as_ref().map(|s| s.replace(":selected:", ""))) + .collect::>() + .join(" ") + }) + .unwrap_or_default() +} + +fn table_to_html(table: &Table) -> String { + let cells = match &table.cells { + Some(cells) => cells, + None => return String::new(), + }; + + let row_count = table.row_count.unwrap_or(0) as usize; + let col_count = table.column_count.unwrap_or(0) as usize; + if row_count == 0 || col_count == 0 { + return String::new(); + } + + let mut html = String::from(""); + + let mut grid = vec![vec![None; col_count]; row_count]; + for cell in cells { + if let (Some(row), Some(col), Some(content)) = + (cell.row_index, cell.column_index, cell.content.as_ref()) + { + if (row as usize) < row_count && (col as usize) < col_count { + grid[row as usize][col as usize] = Some(content.replace(":selected:", "")); + } + } + } + + for row in grid { + html.push_str(""); + for cell in row { + html.push_str(""); + } + html.push_str(""); + } + + html.push_str("
"); + html.push_str(cell.as_deref().unwrap_or("")); + html.push_str("
"); + html +} + +fn table_to_markdown(table: &Table) -> String { + convert_table_to_markdown(table_to_html(table)) +} + +fn create_word_bbox(polygon: &[f64]) -> Result> { + if polygon.len() >= 8 { + let points: Vec = polygon + .iter() + .map(|&coord| inches_to_pixels(coord)) + .collect(); + + let left = points + .iter() + .step_by(2) + .fold(f32::INFINITY, |acc, &x| acc.min(x)); + let top = points + .iter() + .skip(1) + .step_by(2) + .fold(f32::INFINITY, |acc, &y| acc.min(y)); + let right = points + .iter() + .step_by(2) + .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x)); + let bottom = points + .iter() + .skip(1) + .step_by(2) + .fold(f32::NEG_INFINITY, |acc, &y| acc.max(y)); + + Ok(BoundingBox::new(left, top, right - left, bottom - top)) + } else { + Err("Invalid polygon length".into()) + } +} diff --git a/core/src/models/chunkr/mod.rs b/core/src/models/chunkr/mod.rs index 454aee5f5..634494983 100644 --- a/core/src/models/chunkr/mod.rs +++ b/core/src/models/chunkr/mod.rs @@ -12,3 +12,6 @@ pub mod task; pub mod tasks; pub mod upload; pub mod user; + +#[cfg(feature = "azure")] +pub mod azure; diff --git a/core/src/models/chunkr/pipeline.rs b/core/src/models/chunkr/pipeline.rs index 3b2dca4f9..6f842ebde 100644 --- a/core/src/models/chunkr/pipeline.rs +++ b/core/src/models/chunkr/pipeline.rs @@ -43,6 +43,7 @@ impl Pipeline { None, None, None, + None, ) .await?; } @@ -76,6 +77,7 @@ impl Pipeline { Some(Status::Processing), Some("Task started".to_string()), None, + None, Some(Utc::now()), None, None, @@ -125,6 +127,7 @@ impl Pipeline { None, None, None, + None, ) .await?; } @@ -148,6 +151,7 @@ impl Pipeline { Some(status), message, None, + None, Some(finished_at), expires_at, None, @@ -163,6 +167,7 @@ impl Pipeline { message, None, None, + None, Some(finished_at), expires_at, ) diff --git a/core/src/models/chunkr/task.rs b/core/src/models/chunkr/task.rs index be795e7bd..6e2bc0f19 100644 --- a/core/src/models/chunkr/task.rs +++ b/core/src/models/chunkr/task.rs @@ -1,10 +1,10 @@ use crate::configs::worker_config; use crate::models::chunkr::chunk_processing::ChunkProcessing; use crate::models::chunkr::output::{OutputResponse, Segment, SegmentType}; -use crate::models::chunkr::segment_processing::SegmentProcessing; +use crate::models::chunkr::segment_processing::{ + GenerationStrategy, PictureGenerationConfig, SegmentProcessing, +}; use crate::models::chunkr::structured_extraction::JsonSchema; -#[cfg(feature = "azure")] -use crate::models::chunkr::upload::Pipeline; use crate::models::chunkr::upload::{OcrStrategy, SegmentationStrategy}; use crate::utils::clients::get_pg_client; use crate::utils::services::file_operations::check_file_type; @@ -222,16 +222,30 @@ impl Task { let temp_file = download_to_tempfile(&self.output_location, None).await?; let json_content: String = tokio::fs::read_to_string(temp_file.path()).await?; let mut output_response: OutputResponse = serde_json::from_str(&json_content)?; - async fn process(segment: &mut Segment) -> Result> { + let picture_generation_config: PictureGenerationConfig = self + .configuration + .segment_processing + .picture + .clone() + .ok_or(format!("Picture generation config not found"))?; + async fn process( + segment: &mut Segment, + picture_generation_config: &PictureGenerationConfig, + ) -> Result> { let url = generate_presigned_url(segment.image.as_ref().unwrap(), true, None) .await .ok(); if segment.segment_type == SegmentType::Picture { - segment.html = Some(format!( - "", - url.clone().unwrap_or_default() - )); - segment.markdown = Some(format!("![Image]({})", url.clone().unwrap_or_default())); + if picture_generation_config.html == GenerationStrategy::Auto { + segment.html = Some(format!( + "", + url.clone().unwrap_or_default() + )); + } + if picture_generation_config.markdown == GenerationStrategy::Auto { + segment.markdown = + Some(format!("![Image]({})", url.clone().unwrap_or_default())); + } } Ok(url.clone().unwrap_or_default()) } @@ -240,7 +254,7 @@ impl Task { .iter_mut() .flat_map(|chunk| chunk.segments.iter_mut()) .filter(|segment| segment.image.is_some()) - .map(|segment| process(segment)); + .map(|segment| process(segment, &picture_generation_config)); try_join_all(futures).await?; Ok(output_response) @@ -251,6 +265,7 @@ impl Task { status: Option, message: Option, configuration: Option, + page_count: Option, started_at: Option>, finished_at: Option>, expires_at: Option>, @@ -278,6 +293,11 @@ impl Task { self.expires_at = Some(dt); } + if let Some(page_count) = page_count { + update_parts.push(format!("page_count = {}", page_count)); + self.page_count = Some(page_count); + } + if let Some(configuration) = configuration { update_parts.push(format!( "configuration = '{}'", @@ -293,9 +313,25 @@ impl Task { self.user_id ); - client.execute(&query, &[]).await?; - - Ok(()) + match client.execute(&query, &[]).await { + Ok(_) => Ok(()), + Err(e) => { + if e.to_string().contains("usage limit exceeded") { + Box::pin(self.update( + Some(Status::Failed), + Some("Page limit exceeded".to_string()), + None, + None, + None, + None, + None, + )) + .await + } else { + Err(Box::new(e)) + } + } + } } pub async fn delete(&self) -> Result<(), Box> { @@ -345,6 +381,7 @@ impl Task { Some("Finishing up".to_string()), None, None, + None, Some(Utc::now()), None, ) @@ -480,7 +517,7 @@ impl Task { .await .map_err(|_| "Error getting input file url")?; let mut pdf_url = None; - let mut output = None; + let mut output: Option = None; if self.status == Status::Succeeded { pdf_url = Some( generate_presigned_url(&self.pdf_location, true, None) @@ -578,22 +615,11 @@ pub enum Status { Cancelled, } -#[cfg_attr( - feature = "azure", - derive( - Debug, - Serialize, - Deserialize, - PartialEq, - Clone, - ToSql, - FromSql, - ToSchema, - Display, - EnumString, - ) +#[cfg(feature = "azure")] +#[derive( + Debug, Serialize, Deserialize, PartialEq, Clone, ToSql, FromSql, ToSchema, Display, EnumString, )] -pub enum Pipeline { +pub enum PipelineType { Azure, } @@ -619,7 +645,7 @@ pub struct Configuration { pub target_chunk_length: Option, #[cfg(feature = "azure")] #[serde(skip_serializing_if = "Option::is_none")] - pub pipeline: Option, + pub pipeline: Option, } // TODO: Move to output diff --git a/core/src/models/chunkr/upload.rs b/core/src/models/chunkr/upload.rs index 50cbee857..7710131d2 100644 --- a/core/src/models/chunkr/upload.rs +++ b/core/src/models/chunkr/upload.rs @@ -6,7 +6,7 @@ use crate::models::chunkr::segment_processing::SegmentProcessing; use crate::models::chunkr::structured_extraction::JsonSchema; use crate::models::chunkr::task::Configuration; #[cfg(feature = "azure")] -use crate::models::chunkr::task::Pipeline; +use crate::models::chunkr::task::PipelineType; use actix_multipart::form::json::Json as MPJson; use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; use postgres_types::{FromSql, ToSql}; @@ -54,12 +54,12 @@ pub struct CreateForm { /// If 0, each chunk will contain a single segment. pub target_chunk_length: Option>, #[cfg(feature = "azure")] - #[param(style = Form, value_type = Option)] - #[schema(value_type = Option)] - /// The pipeline to use for processing. + #[param(style = Form, value_type = Option)] + #[schema(value_type = Option)] + /// The PipelineType to use for processing. /// If pipeline is set to Azure then Azure layout analysis will be used for segmentation and OCR. /// The output will be unified to the Chunkr `output` format. - pub pipeline: Option>, + pub pipeline: Option>, } impl CreateForm { @@ -159,7 +159,7 @@ impl CreateForm { } #[cfg(feature = "azure")] - fn get_pipeline(&self) -> Option { + fn get_pipeline(&self) -> Option { self.pipeline.as_ref().map(|e| e.0.clone()) } @@ -208,12 +208,12 @@ pub struct UpdateForm { #[schema(value_type = Option)] pub segmentation_strategy: Option>, #[cfg(feature = "azure")] - #[param(style = Form, value_type = Option)] - #[schema(value_type = Option)] + #[param(style = Form, value_type = Option)] + #[schema(value_type = Option)] /// The pipeline to use for processing. /// If pipeline is set to Azure then Azure layout analysis will be used for segmentation and OCR. /// The output will be unified to the Chunkr output. - pub pipeline: Option>, + pub pipeline: Option>, } impl UpdateForm { @@ -300,7 +300,7 @@ impl UpdateForm { .unwrap_or(current_config.segmentation_strategy.clone()), target_chunk_length: None, #[cfg(feature = "azure")] - pipeline: None, + pipeline: self.pipeline.as_ref().map(|e| e.0.clone()), } } } diff --git a/core/src/pipeline/azure.rs b/core/src/pipeline/azure.rs index b78b1bad7..8999b2baf 100644 --- a/core/src/pipeline/azure.rs +++ b/core/src/pipeline/azure.rs @@ -1,7 +1,6 @@ -use crate::models::chunkr::output::Segment; use crate::models::chunkr::pipeline::Pipeline; use crate::models::chunkr::task::Status; -use crate::utils::services::chunking; +use crate::utils::services::azure::perform_azure_analysis; /// Use Azure document layout analysis to perform segmentation and ocr pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box> { @@ -9,7 +8,8 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box Result<(), Box = pipeline - .output - .chunks - .clone() - .into_iter() - .map(|c| c.segments) - .flatten() - .collect(); - - let chunk_processing = pipeline.get_task()?.configuration.chunk_processing.clone(); - - let chunks = chunking::hierarchical_chunking( - segments, - chunk_processing.target_length, - chunk_processing.ignore_headers_and_footers, - )?; + let pdf_file = pipeline.pdf_file.as_ref().ok_or("PDF file not found")?; + let chunks = perform_azure_analysis(&pdf_file).await?; pipeline.output.chunks = chunks; Ok(()) diff --git a/core/src/pipeline/chunking.rs b/core/src/pipeline/chunking.rs index 70da8dd3a..ff3549cea 100644 --- a/core/src/pipeline/chunking.rs +++ b/core/src/pipeline/chunking.rs @@ -16,6 +16,7 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box Result<(), Box Result<(), Box> { None, None, None, + None, ) .await?; diff --git a/core/src/pipeline/mod.rs b/core/src/pipeline/mod.rs index 8dcf73d92..e2ddcb4de 100644 --- a/core/src/pipeline/mod.rs +++ b/core/src/pipeline/mod.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "azure")] -pub mod azure; pub mod chunking; pub mod convert_to_images; pub mod crop; @@ -7,3 +5,6 @@ pub mod segment_processing; pub mod segmentation_and_ocr; pub mod structured_extraction; pub mod update_metadata; + +#[cfg(feature = "azure")] +pub mod azure; diff --git a/core/src/pipeline/segment_processing.rs b/core/src/pipeline/segment_processing.rs index fa5c7f4b0..7a058f590 100644 --- a/core/src/pipeline/segment_processing.rs +++ b/core/src/pipeline/segment_processing.rs @@ -293,8 +293,12 @@ async fn process_segment( ) )?; - segment.html = Some(html); - segment.markdown = Some(markdown); + if segment.html.is_none() { + segment.html = Some(html); + } + if segment.markdown.is_none() { + segment.markdown = Some(markdown); + } segment.llm = llm; Ok(()) } @@ -313,6 +317,7 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { None, None, None, + None, ) .await?; - let texts: Vec = output_response.chunks + let texts: Vec = output_response + .chunks .iter() .map(|chunk| { - chunk.segments + chunk + .segments .iter() .map(|segment| segment.content.clone()) .collect::>() diff --git a/core/src/pipeline/update_metadata.rs b/core/src/pipeline/update_metadata.rs index cccfc62a9..d5d45e59d 100644 --- a/core/src/pipeline/update_metadata.rs +++ b/core/src/pipeline/update_metadata.rs @@ -1,6 +1,5 @@ use crate::models::chunkr::pipeline::Pipeline; use crate::models::chunkr::task::Status; -use crate::utils::clients::get_pg_client; use crate::utils::services::pdf::count_pages; use std::error::Error; @@ -8,49 +7,20 @@ use std::error::Error; /// /// This function calculates the page count for the task and updates the database with the page count pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box> { - pipeline - .get_task()? - .update( - Some(Status::Processing), - Some("Counting pages".to_string()), - None, - None, - None, - None, - ) - .await?; - let client = get_pg_client().await?; - let pdf_file = pipeline.pdf_file.as_ref().unwrap(); let mut task = pipeline.get_task()?; - let task_id = task.task_id.clone(); + task.update( + Some(Status::Processing), + Some("Counting pages".to_string()), + None, + None, + None, + None, + None, + ) + .await?; + let pdf_file = pipeline.pdf_file.as_ref().unwrap(); let page_count = count_pages(pdf_file)?; - let task_query = format!( - "UPDATE tasks SET page_count = {} WHERE task_id = '{}'", - page_count, task_id - ); - match client.execute(&task_query, &[]).await { - Ok(_) => { - task.page_count = Some(page_count); - pipeline.task = Some(task.clone()); - Ok(()) - } - Err(e) => { - if e.to_string().contains("usage limit exceeded") { - pipeline - .get_task()? - .update( - Some(Status::Failed), - Some("Page limit exceeded".to_string()), - None, - None, - None, - None, - ) - .await?; - Ok(()) - } else { - Err(Box::new(e)) - } - } - } + task.update(None, None, None, Some(page_count), None, None, None) + .await?; + Ok(()) } diff --git a/core/src/utils/routes/update_task.rs b/core/src/utils/routes/update_task.rs index 8d63c2ec0..50b6b8acf 100644 --- a/core/src/utils/routes/update_task.rs +++ b/core/src/utils/routes/update_task.rs @@ -25,6 +25,7 @@ pub async fn update_task( None, None, None, + None, ) .await?; diff --git a/core/src/utils/services/azure.rs b/core/src/utils/services/azure.rs new file mode 100644 index 000000000..52f1abfbe --- /dev/null +++ b/core/src/utils/services/azure.rs @@ -0,0 +1,102 @@ +use crate::configs::azure_config; +use crate::models::chunkr::azure::AzureAnalysisResponse; +use crate::models::chunkr::output::Chunk; +use crate::utils::clients; +use crate::utils::retry::retry_with_backoff; +use base64::{engine::general_purpose, Engine as _}; +use serde_json; +use std::error::Error; +use std::fs; +use tempfile::NamedTempFile; + +async fn azure_analysis(temp_file: &NamedTempFile) -> Result, Box> { + let azure_config = azure_config::Config::from_env()?; + let api_version = azure_config.api_version; + let endpoint = azure_config.endpoint; + let key = azure_config.key; + let model_id = azure_config.model_id; + let client = clients::get_reqwest_client(); + + let url = format!( + "{}/documentintelligence/documentModels/{}:analyze?_overload=analyzeDocument&api-version={}", + endpoint.trim_end_matches('/'), + model_id, + api_version + ); + + let file_content = fs::read(temp_file.path())?; + let base64_content = general_purpose::STANDARD.encode(&file_content); + + let request_body = serde_json::json!({ + "base64Source": base64_content + }); + + let response = client + .post(&url) + .header("Ocp-Apim-Subscription-Key", key.clone()) + .json(&request_body) + .send() + .await?; + + if response.status() == 202 { + let operation_location = response + .headers() + .get("operation-location") + .ok_or("No operation-location header found")? + .to_str()?; + + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + let status_response = client + .get(operation_location) + .header("Ocp-Apim-Subscription-Key", key.clone()) + .send() + .await? + .error_for_status()?; + + let azure_response: AzureAnalysisResponse = status_response.json().await?; + + match azure_response.status.as_str() { + "succeeded" => { + let chunks = azure_response.to_chunks()?; + return Ok(chunks); + } + "failed" => return Err("Analysis failed".into()), + "running" | "notStarted" => { + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + continue; + } + _ => return Err("Unknown status".into()), + } + } + } + + Err("Unknown status".into()) +} + +pub async fn perform_azure_analysis( + temp_file: &NamedTempFile, +) -> Result, Box> { + Ok(retry_with_backoff(|| async { azure_analysis(temp_file).await }).await?) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::clients::initialize; + use std::path::Path; + + #[tokio::test] + async fn test_azure_analysis() { + initialize().await; + let test_file_path = Path::new("./input/test.pdf"); + let temp_file = NamedTempFile::new().unwrap(); + fs::copy(test_file_path, temp_file.path()).unwrap(); + let result = azure_analysis(&temp_file) + .await + .expect("Azure analysis failed"); + let json = serde_json::to_string_pretty(&result).unwrap(); + fs::write("azure-analysis-response.json", json).unwrap(); + } +} diff --git a/core/src/utils/services/mod.rs b/core/src/utils/services/mod.rs index c11bcf4d1..ef7356341 100644 --- a/core/src/utils/services/mod.rs +++ b/core/src/utils/services/mod.rs @@ -11,3 +11,6 @@ pub mod pdf; pub mod search; pub mod segmentation; pub mod structured_extraction; + +#[cfg(feature = "azure")] +pub mod azure; diff --git a/core/src/workers/task.rs b/core/src/workers/task.rs index 7aa4b84f6..4cbb1b2ab 100644 --- a/core/src/workers/task.rs +++ b/core/src/workers/task.rs @@ -4,6 +4,9 @@ use core::models::chunkr::pipeline::Pipeline; use core::models::chunkr::task::Status; use core::models::chunkr::task::TaskPayload; use core::models::rrq::queue::QueuePayload; + +#[cfg(feature = "azure")] +use core::pipeline::azure; use core::pipeline::chunking; use core::pipeline::convert_to_images; use core::pipeline::crop; @@ -53,7 +56,7 @@ fn orchestrate_task( #[cfg(feature = "azure")] { match pipeline.get_task()?.configuration.pipeline.clone() { - core::models::task::pipeline::Pipeline::Azure => steps.push("azure"), + Some(core::models::chunkr::task::PipelineType::Azure) => steps.push("azure"), _ => steps.push("segmentation_and_ocr"), } } @@ -62,9 +65,10 @@ fn orchestrate_task( steps.push("segmentation_and_ocr"); } let chunk_processing = pipeline.get_task()?.configuration.chunk_processing.clone(); - if chunk_processing.target_length == 0 || chunk_processing.target_length == 1 { + if chunk_processing.target_length > 1 { steps.push("chunking"); } + steps.push("crop"); steps.push("segment_processing"); let json_schema = pipeline.get_task()?.configuration.json_schema.clone(); if json_schema.is_some() { From 7b3007a0dadd84a587189ba84c4cfa0b76f1fc83 Mon Sep 17 00:00:00 2001 From: Akhilesh Sharma Date: Mon, 20 Jan 2025 14:39:55 -0800 Subject: [PATCH 4/7] updated package --- clients/python-client/src/chunkr_ai/api/config.py | 3 +++ clients/python-client/src/chunkr_ai/models.py | 2 ++ clients/python-client/tests/test_chunkr.py | 15 +++++++++++++++ core/src/pipeline/azure.rs | 2 +- 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/clients/python-client/src/chunkr_ai/api/config.py b/clients/python-client/src/chunkr_ai/api/config.py index 0b8f16a54..444e33d1e 100644 --- a/clients/python-client/src/chunkr_ai/api/config.py +++ b/clients/python-client/src/chunkr_ai/api/config.py @@ -127,6 +127,8 @@ class Model(str, Enum): FAST = "Fast" HIGH_QUALITY = "HighQuality" +class PipelineType(str, Enum): + AZURE = "Azure" class Configuration(BaseModel): chunk_processing: Optional[ChunkProcessing] = Field(default=None) @@ -139,6 +141,7 @@ class Configuration(BaseModel): ocr_strategy: Optional[OcrStrategy] = Field(default=None) segment_processing: Optional[SegmentProcessing] = Field(default=None) segmentation_strategy: Optional[SegmentationStrategy] = Field(default=None) + pipeline: Optional[PipelineType] = Field(default=None) @model_validator(mode="before") def map_deprecated_fields(cls, values: Dict) -> Dict: diff --git a/clients/python-client/src/chunkr_ai/models.py b/clients/python-client/src/chunkr_ai/models.py index dacd3a07e..64ff027cc 100644 --- a/clients/python-client/src/chunkr_ai/models.py +++ b/clients/python-client/src/chunkr_ai/models.py @@ -18,6 +18,7 @@ SegmentType, SegmentationStrategy, Status, + PipelineType, ) from .api.task import TaskResponse @@ -45,4 +46,5 @@ "Status", "TaskResponse", "TaskResponseAsync", + "PipelineType", ] diff --git a/clients/python-client/tests/test_chunkr.py b/clients/python-client/tests/test_chunkr.py index 8f6d92ac7..ecdea2773 100644 --- a/clients/python-client/tests/test_chunkr.py +++ b/clients/python-client/tests/test_chunkr.py @@ -10,6 +10,7 @@ GenerationConfig, JsonSchema, OcrStrategy, + PipelineType, Property, SegmentationStrategy, SegmentProcessing, @@ -411,3 +412,17 @@ async def test_update_task_direct(chunkr_client, sample_path): assert task.status == "Succeeded" assert task.output is not None assert task.configuration.segmentation_strategy == SegmentationStrategy.PAGE + + +@pytest.mark.asyncio +async def test_pipeline_type(chunkr_client, sample_path): + client_type, client = chunkr_client + response = ( + await client.upload(sample_path, Configuration(pipeline=PipelineType.AZURE)) + if client_type == "async" + else client.upload(sample_path, Configuration(pipeline=PipelineType.AZURE)) + ) + + assert response.task_id is not None + assert response.status == "Succeeded" + assert response.output is not None diff --git a/core/src/pipeline/azure.rs b/core/src/pipeline/azure.rs index 8999b2baf..8c98033c7 100644 --- a/core/src/pipeline/azure.rs +++ b/core/src/pipeline/azure.rs @@ -8,7 +8,7 @@ pub async fn process(pipeline: &mut Pipeline) -> Result<(), Box Date: Mon, 20 Jan 2025 17:30:43 -0800 Subject: [PATCH 5/7] Added switch for Azure pipeline to Upload component --- apps/web/src/components/Upload/UploadMain.tsx | 34 +++++++++++++++++++ apps/web/src/models/taskConfig.model.ts | 8 +++++ apps/web/src/models/upload.model.ts | 4 +++ 3 files changed, 46 insertions(+) diff --git a/apps/web/src/components/Upload/UploadMain.tsx b/apps/web/src/components/Upload/UploadMain.tsx index 3a8a78d53..89333b6e8 100644 --- a/apps/web/src/components/Upload/UploadMain.tsx +++ b/apps/web/src/components/Upload/UploadMain.tsx @@ -6,6 +6,7 @@ import { SegmentationStrategy, DEFAULT_UPLOAD_CONFIG, DEFAULT_SEGMENT_PROCESSING, + Pipeline, } from "../../models/taskConfig.model"; import "./UploadMain.css"; import Upload from "./Upload"; @@ -111,6 +112,39 @@ export default function UploadMain({ className={`config-section ${!isAuthenticated ? "disabled" : ""}`} >
+ + + + + Pipeline + + } + value={config.pipeline || "Default"} + onChange={(value) => + setConfig({ + ...config, + pipeline: + value === "Default" ? undefined : (value as Pipeline), + }) + } + options={[ + { label: "Default", value: "Default" }, + { label: "Azure", value: Pipeline.Azure }, + ]} + /> diff --git a/apps/web/src/models/taskConfig.model.ts b/apps/web/src/models/taskConfig.model.ts index c77cc252c..5d0bd9826 100644 --- a/apps/web/src/models/taskConfig.model.ts +++ b/apps/web/src/models/taskConfig.model.ts @@ -178,6 +178,13 @@ export interface UploadFormData { * @default 512 */ target_chunk_length?: number; + + /** Pipeline to run after processing */ + pipeline?: Pipeline; +} + +export enum Pipeline { + Azure = "Azure", } const DEFAULT_SEGMENT_CONFIG: SegmentProcessingConfig = { @@ -227,4 +234,5 @@ export const DEFAULT_UPLOAD_CONFIG: UploadFormData = { segment_processing: DEFAULT_SEGMENT_PROCESSING, json_schema: undefined, // or some default schema if needed file: new File([], ""), + pipeline: undefined, // Default pipeline }; diff --git a/apps/web/src/models/upload.model.ts b/apps/web/src/models/upload.model.ts index 8c03e8cda..ef9b25371 100644 --- a/apps/web/src/models/upload.model.ts +++ b/apps/web/src/models/upload.model.ts @@ -4,6 +4,7 @@ import { OcrStrategy, SegmentProcessing, SegmentationStrategy, + Pipeline, } from "./taskConfig.model"; export interface UploadForm { @@ -30,4 +31,7 @@ export interface UploadForm { /** Strategy for document segmentation */ segmentation_strategy?: SegmentationStrategy; + + /** Pipeline to run after processing */ + pipeline?: Pipeline; } From 741180af7314cd241e9a140c759475b73b9df8cc Mon Sep 17 00:00:00 2001 From: Mehul Chadda Date: Mon, 20 Jan 2025 20:45:49 -0800 Subject: [PATCH 6/7] Added feature flag based ADI functionality --- apps/web/.env.example | 4 + apps/web/src/components/Upload/UploadMain.tsx | 73 ++++++++++--------- apps/web/src/config/env.config.ts | 37 ++++++++++ apps/web/src/models/taskConfig.model.ts | 6 +- apps/web/src/models/upload.model.ts | 3 +- clients/node-client/jest.load.config.js | 6 ++ 6 files changed, 93 insertions(+), 36 deletions(-) create mode 100644 apps/web/src/config/env.config.ts create mode 100644 clients/node-client/jest.load.config.js diff --git a/apps/web/.env.example b/apps/web/.env.example index e94b4dca2..dfc944906 100644 --- a/apps/web/.env.example +++ b/apps/web/.env.example @@ -1,6 +1,10 @@ VITE_API_URL= +# VITE_API_KEY= + VITE_KEYCLOAK_URL= VITE_KEYCLOAK_REALM= VITE_KEYCLOAK_CLIENT_ID= VITE_KEYCLOAK_REDIRECT_URI=http://localhost:5173 VITE_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=http://localhost:5173 + +VITE_FEATURE_FLAG_PIPELINE=false # true enables Azure Document Intelligence layout analysis, OCR and segment processing heuristics diff --git a/apps/web/src/components/Upload/UploadMain.tsx b/apps/web/src/components/Upload/UploadMain.tsx index 89333b6e8..bb444372f 100644 --- a/apps/web/src/components/Upload/UploadMain.tsx +++ b/apps/web/src/components/Upload/UploadMain.tsx @@ -18,6 +18,7 @@ import { } from "./ConfigControls"; import { uploadFile } from "../../services/uploadFileApi"; import { UploadForm } from "../../models/upload.model"; +import { getEnvConfig, WhenEnabled } from "../../config/env.config"; interface UploadMainProps { onSubmit: (config: UploadFormData) => void; @@ -29,6 +30,7 @@ export default function UploadMain({ isAuthenticated, onUploadSuccess, }: UploadMainProps) { + const { features } = getEnvConfig(); const [files, setFiles] = useState([]); const [config, setConfig] = useState(DEFAULT_UPLOAD_CONFIG); const [isUploading, setIsUploading] = useState(false); @@ -112,39 +114,44 @@ export default function UploadMain({ className={`config-section ${!isAuthenticated ? "disabled" : ""}`} >
- - - - - Pipeline - - } - value={config.pipeline || "Default"} - onChange={(value) => - setConfig({ - ...config, - pipeline: - value === "Default" ? undefined : (value as Pipeline), - }) - } - options={[ - { label: "Default", value: "Default" }, - { label: "Azure", value: Pipeline.Azure }, - ]} - /> + {features.pipeline && ( + + + + + Pipeline + + } + value={config.pipeline || "Default"} + onChange={(value) => + setConfig({ + ...config, + pipeline: (features.pipeline + ? value === "Default" + ? undefined + : (value as Pipeline) + : undefined) as WhenEnabled<"pipeline", Pipeline>, + }) + } + options={[ + { label: "Default", value: "Default" }, + { label: "Azure", value: Pipeline.Azure }, + ]} + /> + )} diff --git a/apps/web/src/config/env.config.ts b/apps/web/src/config/env.config.ts new file mode 100644 index 000000000..48444458d --- /dev/null +++ b/apps/web/src/config/env.config.ts @@ -0,0 +1,37 @@ +export interface Features { + pipeline: boolean; + // Add new feature flags here + // example: betaFeature: boolean; +} + +export interface EnvConfig { + features: Features; +} + +export const getEnvConfig = (): EnvConfig => { + return { + features: { + pipeline: import.meta.env.VITE_FEATURE_FLAG_PIPELINE === "true", + // Add new feature implementations here + }, + }; +}; + +export function validateEnvConfig(): void { + const requiredFlags: Array = ["pipeline"]; + + for (const flag of requiredFlags) { + const value = import.meta.env[`VITE_FEATURE_FLAG_${flag.toUpperCase()}`]; + if (value !== "true" && value !== "false") { + throw new Error( + `VITE_FEATURE_FLAG_${flag.toUpperCase()} must be either "true" or "false"`, + ); + } + } +} + +// Type helper for feature-guarded types +export type WhenEnabled< + Flag extends keyof Features, + T, +> = Features[Flag] extends true ? T | undefined : undefined; diff --git a/apps/web/src/models/taskConfig.model.ts b/apps/web/src/models/taskConfig.model.ts index 5d0bd9826..ef473d0aa 100644 --- a/apps/web/src/models/taskConfig.model.ts +++ b/apps/web/src/models/taskConfig.model.ts @@ -133,6 +133,8 @@ export interface JsonSchema { schema_type?: string; } +import { WhenEnabled } from "../config/env.config"; + export interface UploadFormData { /** Optional chunk processing configuration */ chunk_processing?: ChunkProcessing; @@ -180,7 +182,7 @@ export interface UploadFormData { target_chunk_length?: number; /** Pipeline to run after processing */ - pipeline?: Pipeline; + pipeline?: WhenEnabled<"pipeline", Pipeline>; } export enum Pipeline { @@ -234,5 +236,5 @@ export const DEFAULT_UPLOAD_CONFIG: UploadFormData = { segment_processing: DEFAULT_SEGMENT_PROCESSING, json_schema: undefined, // or some default schema if needed file: new File([], ""), - pipeline: undefined, // Default pipeline + pipeline: undefined as WhenEnabled<"pipeline", Pipeline>, // Default pipeline }; diff --git a/apps/web/src/models/upload.model.ts b/apps/web/src/models/upload.model.ts index ef9b25371..2782bd36c 100644 --- a/apps/web/src/models/upload.model.ts +++ b/apps/web/src/models/upload.model.ts @@ -6,6 +6,7 @@ import { SegmentationStrategy, Pipeline, } from "./taskConfig.model"; +import { WhenEnabled } from "../config/env.config"; export interface UploadForm { /** The file to be uploaded */ @@ -33,5 +34,5 @@ export interface UploadForm { segmentation_strategy?: SegmentationStrategy; /** Pipeline to run after processing */ - pipeline?: Pipeline; + pipeline?: WhenEnabled<"pipeline", Pipeline>; } diff --git a/clients/node-client/jest.load.config.js b/clients/node-client/jest.load.config.js new file mode 100644 index 000000000..4919091db --- /dev/null +++ b/clients/node-client/jest.load.config.js @@ -0,0 +1,6 @@ +module.exports = { + preset: "ts-jest", + testEnvironment: "node", + testMatch: ["**/__tests__/**/*.load.test.ts"], + testTimeout: 300000, // 5 minute timeout for load tests +}; From ee6fe982478950576fea20ab4222a2589fca7148 Mon Sep 17 00:00:00 2001 From: Akhilesh Sharma Date: Tue, 21 Jan 2025 12:46:49 -0800 Subject: [PATCH 7/7] segments in alphabetical order --- apps/web/.env.example | 4 +-- .../src/components/Upload/ConfigControls.tsx | 21 ++++++++------- apps/web/src/components/Upload/UploadMain.tsx | 1 + apps/web/src/models/taskConfig.model.ts | 4 +-- compose.yaml | 2 +- core/src/models/chunkr/pipeline.rs | 4 ++- core/src/pipeline/mod.rs | 1 - core/src/pipeline/update_metadata.rs | 26 ------------------- core/src/workers/task.rs | 5 +--- services/doctr/main.py | 2 +- 10 files changed, 21 insertions(+), 49 deletions(-) delete mode 100644 core/src/pipeline/update_metadata.rs diff --git a/apps/web/.env.example b/apps/web/.env.example index dfc944906..f139c085f 100644 --- a/apps/web/.env.example +++ b/apps/web/.env.example @@ -5,6 +5,4 @@ VITE_KEYCLOAK_URL= VITE_KEYCLOAK_REALM= VITE_KEYCLOAK_CLIENT_ID= VITE_KEYCLOAK_REDIRECT_URI=http://localhost:5173 -VITE_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=http://localhost:5173 - -VITE_FEATURE_FLAG_PIPELINE=false # true enables Azure Document Intelligence layout analysis, OCR and segment processing heuristics +VITE_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=http://localhost:5173 \ No newline at end of file diff --git a/apps/web/src/components/Upload/ConfigControls.tsx b/apps/web/src/components/Upload/ConfigControls.tsx index 70d539698..301506fcc 100644 --- a/apps/web/src/components/Upload/ConfigControls.tsx +++ b/apps/web/src/components/Upload/ConfigControls.tsx @@ -171,21 +171,23 @@ export function SegmentProcessingControls({ onChange, showOnlyPage = false, }: SegmentProcessingControlsProps) { - const [selectedType, setSelectedType] = - useState("Text"); - const [isDropdownOpen, setIsDropdownOpen] = useState(false); const segmentTypes = showOnlyPage ? (["Page"] as (keyof SegmentProcessing)[]) - : (Object.keys(value).filter( - (key) => key !== "Page" - ) as (keyof SegmentProcessing)[]); + : (Object.keys(value) + .filter((key) => key !== "Page") + .sort() as (keyof SegmentProcessing)[]); + + const defaultSegmentType = segmentTypes[0]; + const [selectedType, setSelectedType] = + useState(defaultSegmentType); + const [isDropdownOpen, setIsDropdownOpen] = useState(false); const dropdownRef = useRef(null); useEffect(() => { if (showOnlyPage && selectedType !== "Page") { setSelectedType("Page"); } else if (!showOnlyPage && selectedType === "Page") { - setSelectedType("Text"); // or any other default segment type + setSelectedType(defaultSegmentType); } }, [selectedType, showOnlyPage]); @@ -270,9 +272,8 @@ export function SegmentProcessingControls({ {segmentTypes.map((type) => (