Skip to content

Commit 69bc471

Browse files
fix: flush downloaded file before return (#2080)
* chore: extend OCI downloader test coverage * fix: sync data before return
1 parent e355328 commit 69bc471

File tree

3 files changed

+224
-19
lines changed

3 files changed

+224
-19
lines changed

agent-control/src/package/oci/artifact_definitions.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ pub mod tests {
161161
use super::*;
162162
use assert_matches::assert_matches;
163163

164+
impl LocalAgentPackage {
165+
pub fn path(&self) -> &PathBuf {
166+
&self.blob_path
167+
}
168+
}
169+
164170
#[rstest::rstest]
165171
#[case::tar_gz_single_layer(
166172
vec![AGENT_PACKAGE_LAYER_TAR_GZ]

agent-control/src/package/oci/downloader.rs

Lines changed: 217 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl OCIAgentDownloader for OCIArtifactDownloader {
7272
})
7373
.map_err(|e| {
7474
OCIDownloaderError::DownloadingArtifact(format!(
75-
"Download attempts exceeded. Last error: {e}"
75+
"download attempts exceeded. Last error: {e}"
7676
))
7777
})
7878
}
@@ -147,7 +147,7 @@ impl OCIArtifactDownloader {
147147
.map_err(OCIDownloaderError::OciDistribution)?;
148148

149149
let (layer, media_type) = LocalAgentPackage::get_layer(&image_manifest).map_err(|e| {
150-
OCIDownloaderError::DownloadingArtifact(format!("validating package layer: {e}"))
150+
OCIDownloaderError::DownloadingArtifact(format!("validating package manifest: {e}"))
151151
})?;
152152

153153
let layer_path = package_dir.join(layer.digest.replace(':', "_"));
@@ -160,6 +160,9 @@ impl OCIArtifactDownloader {
160160
.await
161161
.map_err(OCIDownloaderError::OciDistribution)?;
162162

163+
// Ensure all data is flushed to disk before returning
164+
file.sync_data().await.map_err(OCIDownloaderError::Io)?;
165+
163166
debug!("Artifact written to {}", layer_path.display());
164167

165168
Ok(LocalAgentPackage::new(media_type, layer_path))
@@ -213,12 +216,23 @@ fn add_cert<'a>(mut certs: Vec<Certificate>, cert: CertificateDer<'a>) -> Vec<Ce
213216

214217
#[cfg(test)]
215218
pub mod tests {
219+
use crate::package::oci::artifact_definitions::{
220+
LayerMediaType, ManifestArtifactType, PackageMediaType,
221+
};
222+
216223
use super::*;
217224
use assert_matches::assert_matches;
225+
use httpmock::prelude::*;
218226
use mockall::mock;
227+
use oci_client::client::ClientProtocol;
228+
use oci_client::manifest::{OciDescriptor, OciImageManifest};
229+
use oci_spec::distribution::Reference;
230+
use ring::digest::{SHA256, digest};
231+
use serde_json::json;
219232
use std::fs::File;
220233
use std::io::Write;
221234
use std::path::PathBuf;
235+
use std::str::FromStr;
222236
use tempfile::tempdir;
223237

224238
mock! {
@@ -232,18 +246,7 @@ pub mod tests {
232246
}
233247
}
234248

235-
const INVALID_TESTING_CERT: &str =
236-
"-----BEGIN CERTIFICATE-----\ninvalid!\n-----END CERTIFICATE-----";
237-
238-
fn valid_testing_cert() -> String {
239-
let subject_alt_names = vec!["localhost".to_string()];
240-
let rcgen::CertifiedKey {
241-
cert,
242-
signing_key: _,
243-
} = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
244-
cert.pem()
245-
}
246-
249+
// ========== Proxy Tests ==========
247250
#[test]
248251
fn test_with_empty_proxy_url() {
249252
let proxy_config = ProxyConfig::from_url("".to_string()); // Assuming ProxyConfig::new method exists
@@ -408,4 +411,204 @@ pub mod tests {
408411
let certificates = certs_from_paths(&ca_bundle_file, ca_bundle_dir).unwrap();
409412
assert_eq!(certificates.len(), 1);
410413
}
414+
415+
const INVALID_TESTING_CERT: &str =
416+
"-----BEGIN CERTIFICATE-----\ninvalid!\n-----END CERTIFICATE-----";
417+
418+
fn valid_testing_cert() -> String {
419+
let subject_alt_names = vec!["localhost".to_string()];
420+
let rcgen::CertifiedKey {
421+
cert,
422+
signing_key: _,
423+
} = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
424+
cert.pem()
425+
}
426+
427+
// ========== Fake OCI server Tests ==========
428+
429+
#[test]
430+
fn test_download_agent_package_success() {
431+
let server = FakeOciServer::new("test-repo", "v1.0.0")
432+
.with_artifact_type(&ManifestArtifactType::AgentPackage.to_string())
433+
.with_layer(
434+
b"test agent package content",
435+
&LayerMediaType::AgentPackage(PackageMediaType::AgentPackageLayerTarGz).to_string(),
436+
)
437+
.build();
438+
439+
let downloader = create_downloader();
440+
let dest_dir = tempdir().unwrap();
441+
let local_agent_package = downloader
442+
.download(&server.reference(), dest_dir.path())
443+
.unwrap();
444+
445+
assert_eq!(
446+
std::fs::read(local_agent_package.path()).unwrap(),
447+
b"test agent package content"
448+
);
449+
}
450+
451+
#[test]
452+
fn test_download_with_multiple_layers() {
453+
let server = FakeOciServer::new("test-repo", "v1.0.0")
454+
.with_artifact_type(&ManifestArtifactType::AgentPackage.to_string())
455+
.with_layer(
456+
b"layer 1 content",
457+
&LayerMediaType::AgentPackage(PackageMediaType::AgentPackageLayerTarGz).to_string(),
458+
)
459+
.with_layer(
460+
b"layer 2 content",
461+
"application/vnd.newrelic.agent.unknown-content.v1",
462+
)
463+
.build();
464+
465+
let downloader = create_downloader();
466+
let dest_dir = tempdir().unwrap();
467+
let local_agent_package = downloader
468+
.download(&server.reference(), dest_dir.path())
469+
.unwrap();
470+
471+
assert_eq!(
472+
std::fs::read(local_agent_package.path()).unwrap(),
473+
b"layer 1 content"
474+
);
475+
}
476+
477+
#[test]
478+
fn test_download_with_invalid_package() {
479+
let server = FakeOciServer::new("test-repo", "v1.0.0")
480+
.with_layer(
481+
b"test content",
482+
&LayerMediaType::AgentPackage(PackageMediaType::AgentPackageLayerTarGz).to_string(),
483+
)
484+
.with_artifact_type("application/vnd.unknown.type.v1")
485+
.build();
486+
487+
let downloader = create_downloader();
488+
let dest_dir = tempdir().unwrap();
489+
let err = downloader
490+
.download(&server.reference(), dest_dir.path())
491+
.unwrap_err();
492+
assert!(err.to_string().contains("validating package manifest"));
493+
}
494+
495+
#[test]
496+
fn test_download_with_missing_manifest() {
497+
let server = MockServer::start();
498+
server.mock(|when, then| {
499+
when.method(GET).path("/v2/test-repo/manifests/v1.0.0");
500+
then.status(404).json_body(json!({
501+
"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]
502+
}));
503+
});
504+
505+
let reference =
506+
Reference::from_str(&format!("{}/test-repo:v1.0.0", server.address())).unwrap();
507+
let downloader = create_downloader();
508+
let dest_dir = tempdir().unwrap();
509+
let err = downloader
510+
.download(&reference, dest_dir.path())
511+
.unwrap_err();
512+
assert!(
513+
err.to_string().contains("download attempts exceeded"),
514+
"{}",
515+
err.to_string()
516+
);
517+
}
518+
519+
fn hex_bytes(bytes: &[u8]) -> String {
520+
bytes.iter().map(|b| format!("{:02x}", b)).collect()
521+
}
522+
523+
struct FakeOciServer {
524+
server: MockServer,
525+
repo: String,
526+
tag: String,
527+
layers: Vec<(String, Vec<u8>)>, // (digest, content)
528+
manifest: OciImageManifest,
529+
}
530+
531+
impl FakeOciServer {
532+
fn new(repo: &str, tag: &str) -> Self {
533+
Self {
534+
server: MockServer::start(),
535+
repo: repo.to_string(),
536+
tag: tag.to_string(),
537+
layers: Vec::new(),
538+
manifest: OciImageManifest::default(),
539+
}
540+
}
541+
fn with_artifact_type(mut self, artifact_type: &str) -> Self {
542+
self.manifest.artifact_type = Some(artifact_type.to_string());
543+
self
544+
}
545+
546+
fn with_layer(mut self, content: &[u8], media_type: &str) -> Self {
547+
let digest = digest(&SHA256, content);
548+
let digest_str = format!("sha256:{}", hex_bytes(digest.as_ref()));
549+
self.layers.push((digest_str, content.to_vec()));
550+
551+
let layer_descriptor = OciDescriptor {
552+
media_type: media_type.to_string(),
553+
digest: self.layers.last().unwrap().0.clone(),
554+
size: content.len() as i64,
555+
..Default::default()
556+
};
557+
self.manifest.layers.push(layer_descriptor);
558+
self
559+
}
560+
561+
fn build(self) -> Self {
562+
self.setup_mocks();
563+
self
564+
}
565+
566+
fn setup_mocks(&self) {
567+
// Mock manifest endpoint
568+
let manifest_clone = self.manifest.clone();
569+
self.server.mock(|when, then| {
570+
when.method(GET)
571+
.path(format!("/v2/{}/manifests/{}", self.repo, self.tag));
572+
then.status(200)
573+
.header("Content-Type", "application/vnd.oci.image.manifest.v1+json")
574+
.json_body_obj(&manifest_clone);
575+
});
576+
577+
// Mock blob endpoints
578+
for (digest, content) in &self.layers {
579+
let content_clone = content.clone();
580+
let digest_clone = digest.clone();
581+
self.server.mock(move |when, then| {
582+
when.method(GET)
583+
.path(format!("/v2/{}/blobs/{}", self.repo, digest_clone));
584+
then.status(200)
585+
.header("Content-Type", "application/octet-stream")
586+
.body(&content_clone);
587+
});
588+
}
589+
}
590+
591+
fn reference(&self) -> Reference {
592+
Reference::from_str(&format!(
593+
"{}/{}:{}",
594+
self.server.address(),
595+
self.repo,
596+
self.tag
597+
))
598+
.unwrap()
599+
}
600+
}
601+
602+
fn create_downloader() -> OCIArtifactDownloader {
603+
let runtime = Arc::new(tokio::runtime::Runtime::new().unwrap());
604+
OCIArtifactDownloader::try_new(
605+
ProxyConfig::default(),
606+
runtime,
607+
ClientConfig {
608+
protocol: ClientProtocol::Http,
609+
..Default::default()
610+
},
611+
)
612+
.unwrap()
613+
}
411614
}

agent-control/tests/on_host/tools/oci_artifact.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,5 @@ pub fn push_agent_package(
112112
}
113113

114114
fn hex_bytes(bytes: &[u8]) -> String {
115-
let mut hex_string = String::new();
116-
for byte in bytes {
117-
hex_string.push_str(&format!("{:02x}", byte));
118-
}
119-
hex_string
115+
bytes.iter().map(|b| format!("{:02x}", b)).collect()
120116
}

0 commit comments

Comments
 (0)