Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fmt/src/document/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Context;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs;
Expand All @@ -21,6 +20,8 @@ use std::path::Path;
use std::path::PathBuf;
use std::time::SystemTime;

use anyhow::Context;

use crate::config::Mapping;
use crate::document::Attributes;
use crate::document::Document;
Expand Down
34 changes: 1 addition & 33 deletions fmt/src/license/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Context;

use crate::config::Config;

#[derive(Debug, Clone)]
pub struct HeaderSource {
pub content: String,
}

impl HeaderSource {
pub fn from_config(config: &Config) -> anyhow::Result<Self> {
// 1. inline_header takes priority.
if let Some(content) = config.inline_header.as_ref().cloned() {
return Ok(HeaderSource { content });
}

// 2. Then, header_path tries to load from base_dir.
let header_path = config
.header_path
.as_ref()
.context("no header source found (both inline_header and header_path are None)")?;
let path = {
let mut path = config.base_dir.clone();
path.push(header_path);
path
};
if let Ok(content) = std::fs::read_to_string(path) {
return Ok(HeaderSource { content });
}

// 3. Finally, fallback to try bundled headers.
bundled_headers(header_path).with_context(|| {
format!("no header source found (header_path is invalid: {header_path})")
})
}
}

macro_rules! match_bundled_headers {
($name:expr, $($file:expr),*) => {
match $name {
Expand All @@ -60,7 +28,7 @@ macro_rules! match_bundled_headers {
}
}

fn bundled_headers(name: &str) -> Option<HeaderSource> {
pub(crate) fn bundled_headers(name: &str) -> Option<HeaderSource> {
match_bundled_headers!(
name,
"Apache-2.0.txt",
Expand Down
101 changes: 97 additions & 4 deletions fmt/src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use crate::git;
use crate::header::matcher::HeaderMatcher;
use crate::header::model::default_headers;
use crate::header::model::deserialize_header_definitions;
use crate::header::model::HeaderDef;
use crate::license::bundled_headers;
use crate::license::HeaderSource;
use crate::selection::Selection;

Expand Down Expand Up @@ -56,6 +58,10 @@ pub fn check_license_header<C: Callback>(
.with_context(|| format!("cannot parse config file: {name}"))?
};

let config_dir = run_config
.parent()
.context("cannot get parent directory of config file")?;

let basedir = config.base_dir.clone();
anyhow::ensure!(
basedir.is_dir(),
Expand Down Expand Up @@ -107,10 +113,9 @@ pub fn check_license_header<C: Callback>(
}
}
}

for additional_header in &config.additional_headers {
let additional_defs = fs::read_to_string(additional_header)
.with_context(|| format!("cannot load header definitions: {additional_header}"))
.and_then(deserialize_header_definitions)?;
let additional_defs = load_additional_headers(additional_header, &config, config_dir)?;
for (k, v) in additional_defs {
match defs.entry(k) {
Entry::Occupied(mut ent) => {
Expand All @@ -123,11 +128,12 @@ pub fn check_license_header<C: Callback>(
}
}
}

defs
};

let header_matcher = {
let header_source = HeaderSource::from_config(&config)?;
let header_source = load_header_sources(&config, config_dir)?;
HeaderMatcher::new(header_source.content)
};

Expand Down Expand Up @@ -164,3 +170,90 @@ pub fn check_license_header<C: Callback>(

Ok(())
}

fn load_additional_headers(
additional_header: impl AsRef<Path>,
config: &Config,
config_dir: &Path,
) -> anyhow::Result<HashMap<String, HeaderDef>> {
let additional_header = additional_header.as_ref();

// 1. Based on config directory.
let path = {
let mut path = config_dir.to_path_buf();
path.push(additional_header);
path
};
if let Ok(content) = fs::read_to_string(&path) {
return deserialize_header_definitions(content)
.with_context(|| format!("cannot load header definitions: {}", path.display()));
}

// 2. Based on the base_dir.
let path = {
let mut path = config.base_dir.clone();
path.push(additional_header);
path
};
if let Ok(content) = fs::read_to_string(&path) {
return deserialize_header_definitions(content)
.with_context(|| format!("cannot load header definitions: {}", path.display()));
}

// 3. Based on current working directory.
if let Ok(content) = fs::read_to_string(additional_header) {
return deserialize_header_definitions(content).with_context(|| {
format!(
"cannot load header definitions: {}",
additional_header.display()
)
});
}

Err(anyhow::anyhow!(
"cannot find header definitions: {}",
additional_header.display()
))
}

fn load_header_sources(config: &Config, config_dir: &Path) -> anyhow::Result<HeaderSource> {
// 1. inline_header takes priority.
if let Some(content) = config.inline_header.as_ref().cloned() {
return Ok(HeaderSource { content });
}

// 2. Then, try to load from header_path.
let header_path = config
.header_path
.as_ref()
.context("no header source found (both inline_header and header_path are None)")?;

// 2.1 Based on config directory.
let path = {
let mut path = config_dir.to_path_buf();
path.push(header_path);
path
};
if let Ok(content) = fs::read_to_string(path) {
return Ok(HeaderSource { content });
}

// 2.2 Based on the base_dir.
let path = {
let mut path = config.base_dir.clone();
path.push(header_path);
path
};
if let Ok(content) = fs::read_to_string(path) {
return Ok(HeaderSource { content });
}

// 2.3 Based on current working directory.
if let Ok(content) = fs::read_to_string(header_path) {
return Ok(HeaderSource { content });
}

// 3. Finally, fallback to try bundled headers.
bundled_headers(header_path)
.with_context(|| format!("no header source found (header_path is invalid: {header_path})"))
}