From c5f8f61bdf00cde489a2eb0a71e85d4939587285 Mon Sep 17 00:00:00 2001 From: Bala FA Date: Mon, 5 Sep 2022 08:41:02 +0530 Subject: [PATCH] Add compose_object() API (#20) Signed-off-by: Bala.FA --- Cargo.toml | 1 + src/s3/args.rs | 346 ++++++++++++++++++++++++++++++++++++- src/s3/client.rs | 419 ++++++++++++++++++++++++++++++++++++++++++++- src/s3/error.rs | 22 ++- src/s3/response.rs | 6 + src/s3/types.rs | 25 +++ src/s3/utils.rs | 43 +++-- tests/tests.rs | 106 ++++++++++++ 8 files changed, 935 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3ef0721..ccd89d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ rand = "0.8.5" serde = { version = "1.0.143", features = ["derive"] } serde_json = "1.0.83" async-std = { version = "1.12.0", features = ["attributes", "tokio1"] } +async-recursion = "1.0.0" [dependencies.reqwest] version = "0.11.11" diff --git a/src/s3/args.rs b/src/s3/args.rs index beb6863..9407e5d 100644 --- a/src/s3/args.rs +++ b/src/s3/args.rs @@ -15,16 +15,18 @@ use crate::s3::error::Error; use crate::s3::sse::{Sse, SseCustomerKey}; -use crate::s3::types::{DeleteObject, Item, NotificationRecords, Part, Retention, SelectRequest}; +use crate::s3::types::{ + DeleteObject, Directive, Item, NotificationRecords, Part, Retention, SelectRequest, +}; use crate::s3::utils::{ check_bucket_name, merge, to_http_header_value, to_iso8601utc, urlencode, Multimap, UtcTime, }; use derivative::Derivative; -const MIN_PART_SIZE: usize = 5_242_880; // 5 MiB -const MAX_PART_SIZE: usize = 5_368_709_120; // 5 GiB -const MAX_OBJECT_SIZE: usize = 5_497_558_138_880; // 5 TiB -const MAX_MULTIPART_COUNT: u16 = 10_000; +pub const MIN_PART_SIZE: usize = 5_242_880; // 5 MiB +pub const MAX_PART_SIZE: usize = 5_368_709_120; // 5 GiB +pub const MAX_OBJECT_SIZE: usize = 5_497_558_138_880; // 5 TiB +pub const MAX_MULTIPART_COUNT: u16 = 10_000; fn object_write_args_headers( extra_headers: Option<&Multimap>, @@ -639,7 +641,7 @@ impl<'a> ObjectConditionalReadArgs<'a> { }) } - pub fn get_headers(&self) -> Multimap { + fn get_range_value(&self) -> String { let (offset, length) = match self.length { Some(_) => (Some(self.offset.unwrap_or(0_usize)), self.length), None => (self.offset, None), @@ -655,7 +657,13 @@ impl<'a> ObjectConditionalReadArgs<'a> { } } + return range; + } + + pub fn get_headers(&self) -> Multimap { let mut headers = Multimap::new(); + + let range = self.get_range_value(); if !range.is_empty() { headers.insert(String::from("Range"), range.clone()); } @@ -696,6 +704,11 @@ impl<'a> ObjectConditionalReadArgs<'a> { } headers.insert(String::from("x-amz-copy-source"), copy_source.to_string()); + let range = self.get_range_value(); + if !range.is_empty() { + headers.insert(String::from("x-amz-copy-source-range"), range.clone()); + } + if let Some(v) = self.match_etag { headers.insert(String::from("x-amz-copy-source-if-match"), v.to_string()); } @@ -733,6 +746,8 @@ pub type GetObjectArgs<'a> = ObjectConditionalReadArgs<'a>; pub type StatObjectArgs<'a> = ObjectConditionalReadArgs<'a>; +pub type CopySource<'a> = ObjectConditionalReadArgs<'a>; + #[derive(Derivative, Clone, Debug, Default)] pub struct RemoveObjectsApiArgs<'a> { pub extra_headers: Option<&'a Multimap>, @@ -1010,3 +1025,322 @@ impl<'a> ListenBucketNotificationArgs<'a> { }) } } + +#[derive(Clone, Debug, Default)] +pub struct UploadPartCopyArgs<'a> { + pub extra_headers: Option<&'a Multimap>, + pub extra_query_params: Option<&'a Multimap>, + pub region: Option<&'a str>, + pub bucket: &'a str, + pub object: &'a str, + pub 
upload_id: &'a str,
+    pub part_number: u16,
+    pub headers: Multimap,
+}
+
+impl<'a> UploadPartCopyArgs<'a> {
+    pub fn new(
+        bucket_name: &'a str,
+        object_name: &'a str,
+        upload_id: &'a str,
+        part_number: u16,
+        headers: Multimap,
+    ) -> Result<UploadPartCopyArgs<'a>, Error> {
+        check_bucket_name(bucket_name, true)?;
+
+        if object_name.is_empty() {
+            return Err(Error::InvalidObjectName(String::from(
+                "object name cannot be empty",
+            )));
+        }
+
+        if upload_id.is_empty() {
+            return Err(Error::InvalidUploadId(String::from(
+                "upload ID cannot be empty",
+            )));
+        }
+
+        if part_number < 1 || part_number > 10000 {
+            return Err(Error::InvalidPartNumber(String::from(
+                "part number must be between 1 and 10000",
+            )));
+        }
+
+        Ok(UploadPartCopyArgs {
+            extra_headers: None,
+            extra_query_params: None,
+            region: None,
+            bucket: bucket_name,
+            object: object_name,
+            upload_id: upload_id,
+            part_number: part_number,
+            headers: headers,
+        })
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct CopyObjectArgs<'a> {
+    pub extra_headers: Option<&'a Multimap>,
+    pub extra_query_params: Option<&'a Multimap>,
+    pub region: Option<&'a str>,
+    pub bucket: &'a str,
+    pub object: &'a str,
+    pub headers: Option<&'a Multimap>,
+    pub user_metadata: Option<&'a Multimap>,
+    pub sse: Option<&'a dyn Sse>,
+    pub tags: Option<&'a std::collections::HashMap<String, String>>,
+    pub retention: Option<&'a Retention>,
+    pub legal_hold: bool,
+    pub source: CopySource<'a>,
+    pub metadata_directive: Option<Directive>,
+    pub tagging_directive: Option<Directive>,
+}
+
+impl<'a> CopyObjectArgs<'a> {
+    pub fn new(
+        bucket_name: &'a str,
+        object_name: &'a str,
+        source: CopySource<'a>,
+    ) -> Result<CopyObjectArgs<'a>, Error> {
+        check_bucket_name(bucket_name, true)?;
+
+        if object_name.is_empty() {
+            return Err(Error::InvalidObjectName(String::from(
+                "object name cannot be empty",
+            )));
+        }
+
+        Ok(CopyObjectArgs {
+            extra_headers: None,
+            extra_query_params: None,
+            region: None,
+            bucket: bucket_name,
+            object: object_name,
+            headers: None,
+            user_metadata: None,
+            sse: None,
+            tags: None,
+            retention: None,
+            legal_hold: false,
+            source: source,
+            metadata_directive: None,
+            tagging_directive: None,
+        })
+    }
+
+    pub fn get_headers(&self) -> Multimap {
+        object_write_args_headers(
+            self.extra_headers,
+            self.headers,
+            self.user_metadata,
+            self.sse,
+            self.tags,
+            self.retention,
+            self.legal_hold,
+        )
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct ComposeSource<'a> {
+    pub extra_headers: Option<&'a Multimap>,
+    pub extra_query_params: Option<&'a Multimap>,
+    pub region: Option<&'a str>,
+    pub bucket: &'a str,
+    pub object: &'a str,
+    pub version_id: Option<&'a str>,
+    pub ssec: Option<&'a SseCustomerKey>,
+    pub offset: Option<usize>,
+    pub length: Option<usize>,
+    pub match_etag: Option<&'a str>,
+    pub not_match_etag: Option<&'a str>,
+    pub modified_since: Option<UtcTime>,
+    pub unmodified_since: Option<UtcTime>,
+
+    object_size: Option<usize>, // populated by build_headers()
+    headers: Option<Multimap>,  // populated by build_headers()
+}
+
+impl<'a> ComposeSource<'a> {
+    pub fn new(bucket_name: &'a str, object_name: &'a str) -> Result<ComposeSource<'a>, Error> {
+        check_bucket_name(bucket_name, true)?;
+
+        if object_name.is_empty() {
+            return Err(Error::InvalidObjectName(String::from(
+                "object name cannot be empty",
+            )));
+        }
+
+        Ok(ComposeSource {
+            extra_headers: None,
+            extra_query_params: None,
+            region: None,
+            bucket: bucket_name,
+            object: object_name,
+            version_id: None,
+            ssec: None,
+            offset: None,
+            length: None,
+            match_etag: None,
+            not_match_etag: None,
+            modified_since: None,
+            unmodified_since: None,
+            object_size: None,
+            headers: None,
+        })
+    }
+
+    pub fn get_object_size(&self) -> usize {
+        return self.object_size.expect("ABORT: ComposeSource::build_headers() must be called prior to this method invocation. This should not happen.");
+    }
+
+    pub fn get_headers(&self) -> Multimap {
+        return self.headers.as_ref().expect("ABORT: ComposeSource::build_headers() must be called prior to this method invocation. This should not happen.").clone();
+    }
+
+    pub fn build_headers(&mut self, object_size: usize, etag: String) -> Result<(), Error> {
+        if let Some(v) = self.offset {
+            if v >= object_size {
+                return Err(Error::InvalidComposeSourceOffset(
+                    self.bucket.to_string(),
+                    self.object.to_string(),
+                    self.version_id.map(|v| v.to_string()),
+                    v,
+                    object_size,
+                ));
+            }
+        }
+
+        if let Some(v) = self.length {
+            if v > object_size {
+                return Err(Error::InvalidComposeSourceLength(
+                    self.bucket.to_string(),
+                    self.object.to_string(),
+                    self.version_id.map(|v| v.to_string()),
+                    v,
+                    object_size,
+                ));
+            }
+
+            if (self.offset.unwrap_or_default() + v) > object_size {
+                return Err(Error::InvalidComposeSourceSize(
+                    self.bucket.to_string(),
+                    self.object.to_string(),
+                    self.version_id.map(|v| v.to_string()),
+                    self.offset.unwrap_or_default() + v,
+                    object_size,
+                ));
+            }
+        }
+
+        self.object_size = Some(object_size);
+
+        let mut headers = Multimap::new();
+
+        let mut copy_source = String::from("/");
+        copy_source.push_str(self.bucket);
+        copy_source.push_str("/");
+        copy_source.push_str(self.object);
+        if let Some(v) = self.version_id {
+            copy_source.push_str("?versionId=");
+            copy_source.push_str(&urlencode(v));
+        }
+        headers.insert(String::from("x-amz-copy-source"), copy_source.to_string());
+
+        if let Some(v) = self.match_etag {
+            headers.insert(String::from("x-amz-copy-source-if-match"), v.to_string());
+        }
+
+        if let Some(v) = self.not_match_etag {
+            headers.insert(
+                String::from("x-amz-copy-source-if-none-match"),
+                v.to_string(),
+            );
+        }
+
+        if let Some(v) = self.modified_since {
+            headers.insert(
+                String::from("x-amz-copy-source-if-modified-since"),
+                to_http_header_value(v),
+            );
+        }
+
+        if let Some(v) = self.unmodified_since {
+            headers.insert(
+                String::from("x-amz-copy-source-if-unmodified-since"),
+                to_http_header_value(v),
+            );
+        }
+
+        if let Some(v) = self.ssec {
+            merge(&mut headers, &v.copy_headers());
+        }
+
+        if !headers.contains_key("x-amz-copy-source-if-match") {
+            headers.insert(String::from("x-amz-copy-source-if-match"), etag);
+        }
+
+        self.headers = Some(headers);
+
+        return Ok(());
+    }
+}
+
+pub struct ComposeObjectArgs<'a> {
+    pub extra_headers: Option<&'a Multimap>,
+    pub extra_query_params: Option<&'a Multimap>,
+    pub region: Option<&'a str>,
+    pub bucket: &'a str,
+    pub object: &'a str,
+    pub headers: Option<&'a Multimap>,
+    pub user_metadata: Option<&'a Multimap>,
+    pub sse: Option<&'a dyn Sse>,
+    pub tags: Option<&'a std::collections::HashMap<String, String>>,
+    pub retention: Option<&'a Retention>,
+    pub legal_hold: bool,
+    pub sources: &'a mut Vec<ComposeSource<'a>>,
+}
+
+impl<'a> ComposeObjectArgs<'a> {
+    pub fn new(
+        bucket_name: &'a str,
+        object_name: &'a str,
+        sources: &'a mut Vec<ComposeSource<'a>>,
+    ) -> Result<ComposeObjectArgs<'a>, Error> {
+        check_bucket_name(bucket_name, true)?;
+
+        if object_name.is_empty() {
+            return Err(Error::InvalidObjectName(String::from(
+                "object name cannot be empty",
+            )));
+        }
+
+        Ok(ComposeObjectArgs {
+            extra_headers: None,
+            extra_query_params: None,
+            region: None,
+            bucket: bucket_name,
+            object: object_name,
+            headers: None,
+            user_metadata: None,
+            sse: None,
+            tags: None,
+            retention: None,
+            legal_hold: false,
+            sources: sources,
+        })
+    }
+
+    pub fn get_headers(&self) -> Multimap {
+        object_write_args_headers(
+            self.extra_headers,
+            self.headers,
+            self.user_metadata,
+            self.sse,
+            self.tags,
+            self.retention,
+            self.legal_hold,
+        )
+    }
+}
diff --git a/src/s3/client.rs b/src/s3/client.rs
index 838f19f..3d9eb46 100644
--- a/src/s3/client.rs
+++ b/src/s3/client.rs
@@ -20,11 +20,12 @@ use crate::s3::http::{BaseUrl, Url};
 use crate::s3::response::*;
 use crate::s3::signer::sign_v4_s3;
 use crate::s3::sse::SseCustomerKey;
-use crate::s3::types::{Bucket, DeleteObject, Item, NotificationRecords, Part};
+use crate::s3::types::{Bucket, DeleteObject, Directive, Item, NotificationRecords, Part};
 use crate::s3::utils::{
     from_iso8601utc, get_default_text, get_option_text, get_text, md5sum_hash, merge, sha256_hash,
     to_amz_date, urldecode, utc_now, Multimap,
 };
+use async_recursion::async_recursion;
 use bytes::{Buf, Bytes};
 use dashmap::DashMap;
 use hyper::http::Method;
@@ -610,7 +611,11 @@ impl<'a> Client<'a> {
         let body = resp.bytes().await?;
         let root = Element::parse(body.reader())?;
 
-        let location = root.get_text().unwrap_or_default().to_string();
+        let mut location = root.get_text().unwrap_or_default().to_string();
+        if location.is_empty() {
+            location = String::from("us-east-1");
+        }
+
         self.region_map
             .insert(bucket_name.to_string(), location.clone());
         Ok(location)
@@ -763,6 +768,364 @@ impl<'a> Client<'a> {
         })
     }
 
+    async fn calculate_part_count(
+        &self,
+        sources: &'a mut Vec<ComposeSource<'a>>,
+    ) -> Result<u16, Error> {
+        let mut object_size = 0_usize;
+        let mut i = 0;
+        let mut part_count = 0_u16;
+
+        let sources_len = sources.len();
+        for source in sources.iter_mut() {
+            if source.ssec.is_some() && !self.base_url.https {
+                return Err(Error::SseTlsRequired(Some(format!(
+                    "source {}/{}{}: ",
+                    source.bucket,
+                    source.object,
+                    source
+                        .version_id
+                        .as_ref()
+                        .map_or(String::new(), |v| String::from("?versionId=") + v)
+                ))));
+            }
+
+            i += 1;
+
+            let mut stat_args = StatObjectArgs::new(source.bucket, source.object)?;
+            stat_args.extra_headers = source.extra_headers;
+            stat_args.extra_query_params = source.extra_query_params;
+            stat_args.region = source.region;
+            stat_args.version_id = source.version_id;
+            stat_args.ssec = source.ssec;
+            stat_args.match_etag = source.match_etag;
+            stat_args.not_match_etag = source.not_match_etag;
+            stat_args.modified_since = source.modified_since;
+            stat_args.unmodified_since = source.unmodified_since;
+
+            let stat_resp = self.stat_object(&stat_args).await?;
+            source.build_headers(stat_resp.size, stat_resp.etag.clone())?;
+
+            let mut size = stat_resp.size;
+            if let Some(l) = source.length {
+                size = l;
+            } else if let Some(o) = source.offset {
+                size -= o;
+            }
+
+            if size < MIN_PART_SIZE && sources_len != 1 && i != sources_len {
+                return Err(Error::InvalidComposeSourcePartSize(
+                    source.bucket.to_string(),
+                    source.object.to_string(),
+                    source.version_id.map(|v| v.to_string()),
+                    size,
+                    MIN_PART_SIZE,
+                ));
+            }
+
+            object_size += size;
+            if object_size > MAX_OBJECT_SIZE {
+                return Err(Error::InvalidObjectSize(object_size));
+            }
+
+            if size > MAX_PART_SIZE {
+                let mut count = size / MAX_PART_SIZE;
+                let mut last_part_size = size - (count * MAX_PART_SIZE);
+                if last_part_size > 0 {
+                    count += 1;
+                } else {
+                    last_part_size = MAX_PART_SIZE;
+                }
+
+                if last_part_size < MIN_PART_SIZE && sources_len != 1 && i != sources_len {
+                    return Err(Error::InvalidComposeSourceMultipart(
+                        source.bucket.to_string(),
+                        source.object.to_string(),
+                        source.version_id.map(|v| v.to_string()),
+                        size,
+                        MIN_PART_SIZE,
+                    ));
+                }
+
+                part_count += count as u16;
+            } else {
+                part_count += 1;
+            }
+
+            if part_count > MAX_MULTIPART_COUNT {
+                return Err(Error::InvalidMultipartCount(MAX_MULTIPART_COUNT));
+            }
+        }
+
+        return Ok(part_count);
+    }
+
+    #[async_recursion(?Send)]
+    pub async fn do_compose_object(
+        &self,
+        args: &mut ComposeObjectArgs<'_>,
+        upload_id: &mut String,
+    ) -> Result<ComposeObjectResponse, Error> {
+        let part_count = self.calculate_part_count(&mut args.sources).await?;
+
+        if part_count == 1 && args.sources[0].offset.is_none() && args.sources[0].length.is_none() {
+            let mut source =
+                ObjectConditionalReadArgs::new(args.sources[0].bucket, args.sources[0].object)?;
+            source.extra_headers = args.sources[0].extra_headers;
+            source.extra_query_params = args.sources[0].extra_query_params;
+            source.region = args.sources[0].region;
+            source.version_id = args.sources[0].version_id;
+            source.ssec = args.sources[0].ssec;
+            source.match_etag = args.sources[0].match_etag;
+            source.not_match_etag = args.sources[0].not_match_etag;
+            source.modified_since = args.sources[0].modified_since;
+            source.unmodified_since = args.sources[0].unmodified_since;
+
+            let mut coargs = CopyObjectArgs::new(args.bucket, args.object, source)?;
+            coargs.extra_headers = args.extra_headers;
+            coargs.extra_query_params = args.extra_query_params;
+            coargs.region = args.region;
+            coargs.headers = args.headers;
+            coargs.user_metadata = args.user_metadata;
+            coargs.sse = args.sse;
+            coargs.tags = args.tags;
+            coargs.retention = args.retention;
+            coargs.legal_hold = args.legal_hold;
+
+            return self.copy_object(&coargs).await;
+        }
+
+        let headers = args.get_headers();
+
+        let mut cmu_args = CreateMultipartUploadArgs::new(args.bucket, args.object)?;
+        cmu_args.extra_query_params = args.extra_query_params;
+        cmu_args.region = args.region;
+        cmu_args.headers = Some(&headers);
+        let resp = self.create_multipart_upload(&cmu_args).await?;
+        upload_id.push_str(&resp.upload_id);
+
+        let mut part_number = 0_u16;
+        let ssec_headers = match args.sse {
+            Some(v) => match v.as_any().downcast_ref::<SseCustomerKey>() {
+                Some(_) => v.headers(),
+                _ => Multimap::new(),
+            },
+            _ => Multimap::new(),
+        };
+
+        let mut parts: Vec<Part> = Vec::new();
+        for source in args.sources.iter() {
+            let mut size = source.get_object_size();
+            if let Some(l) = source.length {
+                size = l;
+            } else if let Some(o) = source.offset {
+                size -= o;
+            }
+
+            let mut offset = source.offset.unwrap_or_default();
+
+            let mut headers = source.get_headers();
+            merge(&mut headers, &ssec_headers);
+
+            if size <= MAX_PART_SIZE {
+                part_number += 1;
+                if let Some(l) = source.length {
+                    headers.insert(
+                        String::from("x-amz-copy-source-range"),
+                        format!("bytes={}-{}", offset, offset + l - 1),
+                    );
+                } else if source.offset.is_some() {
+                    headers.insert(
+                        String::from("x-amz-copy-source-range"),
+                        format!("bytes={}-{}", offset, offset + size - 1),
+                    );
+                }
+
+                let mut upc_args = UploadPartCopyArgs::new(
+                    args.bucket,
+                    args.object,
+                    upload_id,
+                    part_number,
+                    headers,
+                )?;
+                upc_args.region = args.region;
+
+                let resp = self.upload_part_copy(&upc_args).await?;
+                parts.push(Part {
+                    number: part_number,
+                    etag: resp.etag,
+                });
+            } else {
+                while size > 0 {
+                    part_number += 1;
+
+                    let start_bytes = offset;
+                    let mut end_bytes = start_bytes + MAX_PART_SIZE;
+                    if size < MAX_PART_SIZE {
+                        end_bytes = start_bytes + size;
+                    }
+
+                    let mut headers_copy = headers.clone();
+                    headers_copy.insert(
+                        String::from("x-amz-copy-source-range"),
+                        // end_bytes is exclusive; HTTP ranges are inclusive.
+                        format!("bytes={}-{}", start_bytes, end_bytes - 1),
+                    );
+
+                    let mut upc_args = UploadPartCopyArgs::new(
+                        args.bucket,
+                        args.object,
+                        upload_id,
+                        part_number,
+                        headers_copy,
+                    )?;
+                    upc_args.region = args.region;
+
+                    let resp = self.upload_part_copy(&upc_args).await?;
+                    parts.push(Part {
+                        number: part_number,
+                        etag: resp.etag,
+                    });
+
+                    offset = end_bytes;
+                    size -= end_bytes - start_bytes;
+                }
+            }
+        }
+
+        let mut cmu_args =
+            CompleteMultipartUploadArgs::new(args.bucket, args.object, upload_id, &parts)?;
+        cmu_args.region = args.region;
+        return self.complete_multipart_upload(&cmu_args).await;
+    }
+
+    pub async fn compose_object(
+        &self,
+        args: &mut ComposeObjectArgs<'_>,
+    ) -> Result<ComposeObjectResponse, Error> {
+        if let Some(v) = &args.sse {
+            if v.tls_required() && !self.base_url.https {
+                return Err(Error::SseTlsRequired(None));
+            }
+        }
+
+        let mut upload_id = String::new();
+        let res = self.do_compose_object(args, &mut upload_id).await;
+        if res.is_err() && !upload_id.is_empty() {
+            let amuargs = &AbortMultipartUploadArgs::new(&args.bucket, &args.object, &upload_id)?;
+            self.abort_multipart_upload(&amuargs).await?;
+        }
+
+        return res;
+    }
+
+    pub async fn copy_object(
+        &self,
+        args: &CopyObjectArgs<'_>,
+    ) -> Result<CopyObjectResponse, Error> {
+        if let Some(v) = &args.sse {
+            if v.tls_required() && !self.base_url.https {
+                return Err(Error::SseTlsRequired(None));
+            }
+        }
+
+        if args.source.ssec.is_some() && !self.base_url.https {
+            return Err(Error::SseTlsRequired(None));
+        }
+
+        let stat_resp = self.stat_object(&args.source).await?;
+
+        if args.source.offset.is_some()
+            || args.source.length.is_some()
+            || stat_resp.size > MAX_PART_SIZE
+        {
+            if let Some(v) = &args.metadata_directive {
+                match v {
+                    Directive::Copy => return Err(Error::InvalidCopyDirective(String::from("COPY metadata directive is not applicable to source object size greater than 5 GiB"))),
+                    _ => {} // Nothing to do.
+                }
+            }
+
+            if let Some(v) = &args.tagging_directive {
+                match v {
+                    Directive::Copy => return Err(Error::InvalidCopyDirective(String::from("COPY tagging directive is not applicable to source object size greater than 5 GiB"))),
+                    _ => {} // Nothing to do.
+                }
+            }
+
+            let mut src = ComposeSource::new(args.source.bucket, args.source.object)?;
+            src.extra_headers = args.source.extra_headers;
+            src.extra_query_params = args.source.extra_query_params;
+            src.region = args.source.region;
+            src.ssec = args.source.ssec;
+            src.offset = args.source.offset;
+            src.length = args.source.length;
+            src.match_etag = args.source.match_etag;
+            src.not_match_etag = args.source.not_match_etag;
+            src.modified_since = args.source.modified_since;
+            src.unmodified_since = args.source.unmodified_since;
+
+            let mut sources: Vec<ComposeSource> = Vec::new();
+            sources.push(src);
+
+            let mut coargs = ComposeObjectArgs::new(args.bucket, args.object, &mut sources)?;
+            coargs.extra_headers = args.extra_headers;
+            coargs.extra_query_params = args.extra_query_params;
+            coargs.region = args.region;
+            coargs.headers = args.headers;
+            coargs.user_metadata = args.user_metadata;
+            coargs.sse = args.sse;
+            coargs.tags = args.tags;
+            coargs.retention = args.retention;
+            coargs.legal_hold = args.legal_hold;
+
+            return self.compose_object(&mut coargs).await;
+        }
+
+        let mut headers = args.get_headers();
+        if let Some(v) = &args.metadata_directive {
+            headers.insert(String::from("x-amz-metadata-directive"), v.to_string());
+        }
+        if let Some(v) = &args.tagging_directive {
+            headers.insert(String::from("x-amz-tagging-directive"), v.to_string());
+        }
+        merge(&mut headers, &args.source.get_copy_headers());
+
+        let mut query_params = Multimap::new();
+        if let Some(v) = &args.extra_query_params {
+            merge(&mut query_params, v);
+        }
+
+        let region = self.get_region(&args.bucket, args.region).await?;
+
+        let resp = self
+            .execute(
+                Method::PUT,
+                &region,
+                &mut headers,
+                &query_params,
+                Some(&args.bucket),
+                Some(&args.object),
+                None,
+            )
+            .await?;
+
+        let header_map = resp.headers().clone();
+        let body = resp.bytes().await?;
+        let root = Element::parse(body.reader())?;
+
+        Ok(CopyObjectResponse {
+            headers: header_map.clone(),
+            bucket_name: args.bucket.to_string(),
+            object_name: args.object.to_string(),
+            location: region.clone(),
+            etag: get_text(&root, "ETag")?.trim_matches('"').to_string(),
+            version_id: match header_map.get("x-amz-version-id") {
+                Some(v) => Some(v.to_str()?.to_string()),
+                None => None,
+            },
+        })
+    }
+
     pub async fn create_multipart_upload(
         &self,
         args: &CreateMultipartUploadArgs<'_>,
@@ -839,7 +1202,7 @@ impl<'a> Client<'a> {
 
     pub async fn get_object(&self, args: &GetObjectArgs<'_>) -> Result<reqwest::Response, Error> {
         if args.ssec.is_some() && !self.base_url.https {
-            return Err(Error::SseTlsRequired);
+            return Err(Error::SseTlsRequired(None));
         }
 
         let region = self.get_region(&args.bucket, args.region).await?;
@@ -1543,7 +1906,7 @@ impl<'a> Client<'a> {
     ) -> Result<PutObjectResponse, Error> {
         if let Some(v) = &args.sse {
             if v.tls_required() && !self.base_url.https {
-                return Err(Error::SseTlsRequired);
+                return Err(Error::SseTlsRequired(None));
            }
         }
 
@@ -1842,7 +2205,7 @@ impl<'a> Client<'a> {
         args: &SelectObjectContentArgs<'_>,
     ) -> Result<SelectObjectContentResponse, Error> {
         if args.ssec.is_some() && !self.base_url.https {
-            return Err(Error::SseTlsRequired);
+            return Err(Error::SseTlsRequired(None));
         }
 
         let region = self.get_region(&args.bucket, args.region).await?;
@@ -1885,7 +2248,7 @@ impl<'a> Client<'a> {
         args: &StatObjectArgs<'_>,
    ) -> Result<StatObjectResponse, Error> {
         if args.ssec.is_some() && !self.base_url.https {
-            return Err(Error::SseTlsRequired);
+            return Err(Error::SseTlsRequired(None));
         }
 
         let region = self.get_region(&args.bucket, args.region).await?;
@@ -1943,5 +2306,47 @@ impl<'a> Client<'a> {
         self.put_object_api(&poa_args).await
     }
 
-    // UploadPartCopyResponse UploadPartCopy(UploadPartCopyArgs args);
+    pub async fn upload_part_copy(
+        &self,
+        args: &UploadPartCopyArgs<'_>,
+    ) -> Result<UploadPartCopyResponse, Error> {
+        let region = self.get_region(&args.bucket, args.region).await?;
+
+        let mut headers = Multimap::new();
+        if let Some(v) = &args.extra_headers {
+            merge(&mut headers, v);
+        }
+        merge(&mut headers, &args.headers);
+
+        let mut query_params = Multimap::new();
+        if let Some(v) = &args.extra_query_params {
+            merge(&mut query_params, v);
+        }
+        query_params.insert(String::from("partNumber"), args.part_number.to_string());
+        query_params.insert(String::from("uploadId"), args.upload_id.to_string());
+
+        let resp = self
+            .execute(
+                Method::PUT,
+                &region,
+                &mut headers,
+                &query_params,
+                Some(&args.bucket),
+                Some(&args.object),
+                None,
+            )
+            .await?;
+        let header_map = resp.headers().clone();
+        let body = resp.bytes().await?;
+        let root = Element::parse(body.reader())?;
+
+        Ok(PutObjectBaseResponse {
+            headers: header_map.clone(),
+            bucket_name: args.bucket.to_string(),
+            object_name: args.object.to_string(),
+            location: region.clone(),
+            etag: get_text(&root, "ETag")?.trim_matches('"').to_string(),
+            version_id: None,
+        })
+    }
 }
diff --git a/src/s3/error.rs b/src/s3/error.rs
index f3d249b..71184c9 100644
--- a/src/s3/error.rs
+++ b/src/s3/error.rs
@@ -79,7 +79,7 @@ pub enum Error {
     InvalidObjectSize(usize),
     MissingPartSize,
     InvalidPartCount(usize, usize, u16),
-    SseTlsRequired,
+    SseTlsRequired(Option<String>),
     InsufficientData(usize, usize),
     InvalidLegalHold(String),
     InvalidSelectExpression(String),
@@ -88,6 +88,15 @@ pub enum Error {
     UnknownEventType(String),
     SelectError(String, String),
     UnsupportedApi(String),
+    InvalidComposeSource(String),
+    InvalidComposeSourceOffset(String, String, Option<String>, usize, usize),
+    InvalidComposeSourceLength(String, String, Option<String>, usize, usize),
+    InvalidComposeSourceSize(String, String, Option<String>, usize, usize),
+    InvalidComposeSourcePartSize(String, String, Option<String>, usize, usize),
+    InvalidComposeSourceMultipart(String, String, Option<String>, usize, usize),
+    InvalidDirective(String),
+    InvalidCopyDirective(String),
+    InvalidMultipartCount(u16),
 }
 
 impl std::error::Error for Error {}
@@ -117,7 +126,7 @@ impl fmt::Display for Error {
             Error::InvalidObjectSize(s) => write!(f, "object size {} is not supported; maximum allowed 5TiB", s),
             Error::MissingPartSize => write!(f, "valid part size must be provided when object size is unknown"),
             Error::InvalidPartCount(os, ps, pc) => write!(f, "object size {} and part size {} make more than {} parts for upload", os, ps, pc),
-            Error::SseTlsRequired => write!(f, "SSE operation must be performed over a secure connection"),
+            Error::SseTlsRequired(m) => write!(f, "{}SSE operation must be performed over a secure connection", m.as_ref().map_or(String::new(), |v| v.clone())),
             Error::InsufficientData(ps, br) => write!(f, "not enough data in the stream; expected: {}, got: {} bytes", ps, br),
             Error::InvalidBaseUrl(m) => write!(f, "{}", m),
             Error::UrlBuildError(m) => write!(f, "{}", m),
@@ -132,6 +141,15 @@ impl fmt::Display for Error {
             Error::UnknownEventType(et) => write!(f, "unknown event type {}", et),
             Error::SelectError(ec, em) => write!(f, "error code: {}, error message: {}", ec, em),
             Error::UnsupportedApi(a) => write!(f, "{} API is not supported in Amazon AWS S3", a),
+            Error::InvalidComposeSource(m) => write!(f, "{}", m),
+            Error::InvalidComposeSourceOffset(b, o, v, of, os) => write!(f, "source {}/{}{}: offset {} is beyond object size {}", b, o, v.as_ref().map_or(String::new(), |v| String::from("?versionId=") + v), of, os),
+            Error::InvalidComposeSourceLength(b, o, v, l, os) => write!(f, "source {}/{}{}: length {} is beyond object size {}", b, o, v.as_ref().map_or(String::new(), |v| String::from("?versionId=") + v), l, os),
+            Error::InvalidComposeSourceSize(b, o, v, cs, os) => write!(f, "source {}/{}{}: compose size {} is beyond object size {}", b, o, v.as_ref().map_or(String::new(), |v| String::from("?versionId=") + v), cs, os),
+            Error::InvalidDirective(m) => write!(f, "{}", m),
+            Error::InvalidCopyDirective(m) => write!(f, "{}", m),
+            Error::InvalidComposeSourcePartSize(b, o, v, s, es) => write!(f, "source {}/{}{}: size {} must be greater than {}", b, o, v.as_ref().map_or(String::new(), |v| String::from("?versionId=") + v), s, es),
+            Error::InvalidComposeSourceMultipart(b, o, v, s, es) => write!(f, "source {}/{}{}: size {} for multipart split upload; last part size is less than {}", b, o, v.as_ref().map_or(String::new(), |v| String::from("?versionId=") + v), s, es),
+            Error::InvalidMultipartCount(c) => write!(f, "compose sources create more than allowed multipart count {}", c),
         }
     }
 }
diff --git a/src/s3/response.rs b/src/s3/response.rs
index 2663785..cf93eb6 100644
--- a/src/s3/response.rs
+++ b/src/s3/response.rs
@@ -83,6 +83,12 @@ pub type UploadPartResponse = PutObjectApiResponse;
 
 pub type PutObjectResponse = PutObjectApiResponse;
 
+pub type UploadPartCopyResponse = PutObjectApiResponse;
+
+pub type CopyObjectResponse = PutObjectApiResponse;
+
+pub type ComposeObjectResponse = PutObjectApiResponse;
+
 #[derive(Debug)]
 pub struct StatObjectResponse {
     pub headers: HeaderMap,
diff --git a/src/s3/types.rs b/src/s3/types.rs
index 3a4c594..557da35 100644
--- a/src/s3/types.rs
+++ b/src/s3/types.rs
@@ -566,3 +566,28 @@ pub struct NotificationRecords {
     #[serde(alias = "Records")]
     pub records: Vec<NotificationRecord>,
 }
+
+#[derive(Clone, Debug)]
+pub enum Directive {
+    Copy,
+    Replace,
+}
+
+impl Directive {
+    pub fn parse(s: &str) -> Result<Directive, Error> {
+        match s {
+            "COPY" => Ok(Directive::Copy),
+            "REPLACE" => Ok(Directive::Replace),
+            _ => Err(Error::InvalidDirective(s.to_string())),
+        }
+    }
+}
+
+impl fmt::Display for Directive {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Directive::Copy => write!(f, "COPY"),
+            Directive::Replace => write!(f, "REPLACE"),
+        }
+    }
+}
diff --git a/src/s3/utils.rs b/src/s3/utils.rs
index 898813b..691e1b3 100644
--- a/src/s3/utils.rs
+++ b/src/s3/utils.rs
@@ -24,6 +24,7 @@ use md5::compute as md5compute;
 use multimap::MultiMap;
 use regex::Regex;
 use sha2::{Digest, Sha256};
+use std::collections::BTreeMap;
 pub use urlencoding::decode as urldecode;
 pub use urlencoding::encode as urlencode;
 use xmltree::Element;
@@ -155,41 +156,47 @@ pub fn get_canonical_headers(map: &Multimap) -> (String, String) {
     lazy_static! {
         static ref MULTI_SPACE_REGEX: Regex = Regex::new("( +)").unwrap();
     }
 
-    let mut signed_headers: Vec<String> = Vec::new();
-    let mut mmap: MultiMap<String, String> = MultiMap::new();
+    let mut btmap: BTreeMap<String, String> = BTreeMap::new();
 
     for (k, values) in map.iter_all() {
         let key = k.to_lowercase();
         if "authorization" == key || "user-agent" == key {
             continue;
         }
 
-        if !signed_headers.contains(&key) {
-            signed_headers.push(key.clone());
-        }
-
-        for v in values {
-            mmap.insert(key.clone(), v.to_string());
-        }
-    }
+        let mut vs = values.clone();
+        vs.sort();
 
-    let mut canonical_headers: Vec<String> = Vec::new();
-    for (key, values) in mmap.iter_all_mut() {
-        values.sort();
         let mut value = String::new();
-        for v in values {
+        for v in vs {
             if !value.is_empty() {
                 value.push_str(",");
             }
-            let s: String = MULTI_SPACE_REGEX.replace_all(v, " ").to_string();
+            let s: String = MULTI_SPACE_REGEX.replace_all(&v, " ").to_string();
             value.push_str(&s);
         }
-        canonical_headers.push(key.to_string() + ":" + value.as_str());
+        btmap.insert(key.clone(), value.clone());
     }
 
-    signed_headers.sort();
-    canonical_headers.sort();
+    let mut signed_headers = String::new();
+    let mut canonical_headers = String::new();
+    let mut add_delim = false;
+    for (key, value) in &btmap {
+        if add_delim {
+            signed_headers.push_str(";");
+            canonical_headers.push_str("\n");
+        }
 
-    return (signed_headers.join(";"), canonical_headers.join("\n"));
+        signed_headers.push_str(key);
+
+        canonical_headers.push_str(key);
+        canonical_headers.push_str(":");
+        canonical_headers.push_str(value);
+
+        add_delim = true;
+    }
+
+    return (signed_headers, canonical_headers);
 }
 
 pub fn check_bucket_name(bucket_name: &str, strict: bool) -> Result<(), Error> {
diff --git a/tests/tests.rs b/tests/tests.rs
index 40a17de..bb5333f 100644
--- a/tests/tests.rs
+++ b/tests/tests.rs
@@ -498,6 +498,106 @@ impl<'a> ClientTest<'_> {
         spawned_task.await;
         assert_eq!(receiver.recv().await.unwrap(), true);
     }
+
+    async fn copy_object(&self) {
+        let src_object_name = rand_object_name();
+
+        let size = 16_usize;
+        self.client
+            .put_object(
+                &mut PutObjectArgs::new(
+                    &self.test_bucket,
+                    &src_object_name,
+                    &mut RandReader::new(size),
+                    Some(size),
+                    None,
+                )
+                .unwrap(),
+            )
+            .await
+            .unwrap();
+
+        let object_name = rand_object_name();
+        self.client
+            .copy_object(
+                &CopyObjectArgs::new(
+                    &self.test_bucket,
+                    &object_name,
+                    CopySource::new(&self.test_bucket, &src_object_name).unwrap(),
+                )
+                .unwrap(),
+            )
+            .await
+            .unwrap();
+
+        let resp = self
+            .client
+            .stat_object(&StatObjectArgs::new(&self.test_bucket, &object_name).unwrap())
+            .await
+            .unwrap();
+        assert_eq!(resp.size, size);
+
+        self.client
+            .remove_object(&RemoveObjectArgs::new(&self.test_bucket, &object_name).unwrap())
+            .await
+            .unwrap();
+
+        self.client
+            .remove_object(&RemoveObjectArgs::new(&self.test_bucket, &src_object_name).unwrap())
+            .await
+            .unwrap();
+    }
+
+    async fn compose_object(&self) {
+        let src_object_name = rand_object_name();
+
+        let size = 16_usize;
+        self.client
+            .put_object(
+                &mut PutObjectArgs::new(
+                    &self.test_bucket,
+                    &src_object_name,
+                    &mut RandReader::new(size),
+                    Some(size),
+                    None,
+                )
+                .unwrap(),
+            )
+            .await
+            .unwrap();
+
+        let mut s1 = ComposeSource::new(&self.test_bucket, &src_object_name).unwrap();
+        s1.offset = Some(3);
+        s1.length = Some(5);
+        let mut sources: Vec<ComposeSource> = Vec::new();
+        sources.push(s1);
+
+        let object_name = rand_object_name();
+
+        self.client
+            .compose_object(
+                &mut ComposeObjectArgs::new(&self.test_bucket, &object_name, &mut sources).unwrap(),
+            )
+            .await
+            .unwrap();
+
+        let resp = self
+            .client
+            .stat_object(&StatObjectArgs::new(&self.test_bucket, &object_name).unwrap())
+            .await
+            .unwrap();
+        assert_eq!(resp.size, 5);
+
+        self.client
+            .remove_object(&RemoveObjectArgs::new(&self.test_bucket, &object_name).unwrap())
+            .await
+            .unwrap();
+
+        self.client
+            .remove_object(&RemoveObjectArgs::new(&self.test_bucket, &src_object_name).unwrap())
+            .await
+            .unwrap();
+    }
 }
 
 #[tokio::main]
@@ -528,6 +628,9 @@ async fn s3_tests() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     );
     ctest.init().await;
 
+    println!("compose_object()");
+    ctest.compose_object().await;
+
     println!("make_bucket() + bucket_exists() + remove_bucket()");
     ctest.bucket_exists().await;
 
@@ -555,6 +658,9 @@ async fn s3_tests() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     println!("listen_bucket_notification()");
     ctest.listen_bucket_notification().await;
 
+    println!("copy_object()");
+    ctest.copy_object().await;
+
     ctest.drop().await;
 
     Ok(())
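
A minimal usage sketch of the new API (not part of the patch itself): it assumes an already configured Client named `client`, an existing bucket `my-bucket`, and source objects `part-a` and `part-b`; all of these names are illustrative.

    use minio::s3::args::{ComposeObjectArgs, ComposeSource};

    // Stitch two existing objects into one object, entirely server side.
    // Every source except the last must be at least 5 MiB (MIN_PART_SIZE).
    let mut sources: Vec<ComposeSource> = Vec::new();
    sources.push(ComposeSource::new("my-bucket", "part-a")?);

    let mut s2 = ComposeSource::new("my-bucket", "part-b")?;
    s2.offset = Some(0); // optionally compose only a byte range of a source
    s2.length = Some(1024);
    sources.push(s2);

    let mut args = ComposeObjectArgs::new("my-bucket", "joined-object", &mut sources)?;
    let resp = client.compose_object(&mut args).await?;
    println!("created {} with etag {}", resp.object_name, resp.etag);

If source validation or any part copy fails midway, compose_object() aborts the multipart upload it created, so no incomplete upload is left behind.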