Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 243 additions & 40 deletions src/shell/commands/head.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,47 +81,147 @@ fn copy_lines<F: FnMut(&mut [u8]) -> Result<usize>>(
Ok(ExecuteResult::from_exit_code(0))
}

fn copy_all_but_last_lines<F: FnMut(&mut [u8]) -> Result<usize>>(
writer: &mut ShellPipeWriter,
skip_last: u64,
kill_signal: &KillSignal,
mut read: F,
) -> Result<ExecuteResult> {
// read all content first
let mut content = Vec::new();
let mut buffer = vec![0; 512];
loop {
if let Some(exit_code) = kill_signal.aborted_code() {
return Ok(ExecuteResult::from_exit_code(exit_code));
}
let read_bytes = read(&mut buffer)?;
if read_bytes == 0 {
break;
}
content.extend_from_slice(&buffer[..read_bytes]);
}

// count total lines
let total_lines = content.iter().filter(|&&b| b == b'\n').count() as u64;

// output all but the last N lines
if total_lines <= skip_last {
return Ok(ExecuteResult::from_exit_code(0));
}
let lines_to_print = total_lines - skip_last;

let mut line_count = 0u64;
let mut start = 0;
for (i, &b) in content.iter().enumerate() {
if let Some(exit_code) = kill_signal.aborted_code() {
return Ok(ExecuteResult::from_exit_code(exit_code));
}
if b == b'\n' {
line_count += 1;
if line_count <= lines_to_print {
writer.write_all(&content[start..=i])?;
}
start = i + 1;
if line_count >= lines_to_print {
break;
}
}
}

Ok(ExecuteResult::from_exit_code(0))
}

fn execute_head(mut context: ShellCommandContext) -> Result<ExecuteResult> {
let flags = parse_args(&context.args)?;
if flags.path == "-" {
copy_lines(
&mut context.stdout,
flags.lines,
context.state.kill_signal(),
|buf| context.stdin.read(buf),
512,
)
} else {
let path = flags.path;
match File::open(context.state.cwd().join(path)) {
Ok(mut file) => copy_lines(
&mut context.stdout,
flags.lines,
context.state.kill_signal(),
|buf| file.read(buf).map_err(Into::into),
512,
),
Err(err) => {
context.stderr.write_line(&format!(
"head: {}: {}",
path.to_string_lossy(),
err
))?;
Ok(ExecuteResult::from_exit_code(1))
match flags.lines {
LineCount::First(max_lines) => {
if flags.path == "-" {
copy_lines(
&mut context.stdout,
max_lines,
context.state.kill_signal(),
|buf| context.stdin.read(buf),
512,
)
} else {
let path = flags.path;
match File::open(context.state.cwd().join(path)) {
Ok(mut file) => copy_lines(
&mut context.stdout,
max_lines,
context.state.kill_signal(),
|buf| file.read(buf).map_err(Into::into),
512,
),
Err(err) => {
context.stderr.write_line(&format!(
"head: {}: {}",
path.to_string_lossy(),
err
))?;
Ok(ExecuteResult::from_exit_code(1))
}
}
}
}
LineCount::AllButLast(skip_last) => {
if flags.path == "-" {
copy_all_but_last_lines(
&mut context.stdout,
skip_last,
context.state.kill_signal(),
|buf| context.stdin.read(buf),
)
} else {
let path = flags.path;
match File::open(context.state.cwd().join(path)) {
Ok(mut file) => copy_all_but_last_lines(
&mut context.stdout,
skip_last,
context.state.kill_signal(),
|buf| file.read(buf).map_err(Into::into),
),
Err(err) => {
context.stderr.write_line(&format!(
"head: {}: {}",
path.to_string_lossy(),
err
))?;
Ok(ExecuteResult::from_exit_code(1))
}
}
}
}
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
enum LineCount {
/// print first N lines
First(u64),
/// print all but last N lines
AllButLast(u64),
}

#[derive(Debug, PartialEq)]
struct HeadFlags<'a> {
path: &'a OsStr,
lines: u64,
lines: LineCount,
}

fn parse_line_count(s: &str) -> Result<LineCount> {
if let Some(rest) = s.strip_prefix('-') {
let num = rest.parse::<u64>()?;
Ok(LineCount::AllButLast(num))
} else {
let num = s.parse::<u64>()?;
Ok(LineCount::First(num))
}
}

fn parse_args<'a>(args: &'a [OsString]) -> Result<HeadFlags<'a>> {
let mut path: Option<&'a OsStr> = None;
let mut lines: Option<u64> = None;
let mut lines: Option<LineCount> = None;
let mut iterator = parse_arg_kinds(args).into_iter();
while let Some(arg) = iterator.next() {
match arg {
Expand All @@ -137,9 +237,11 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result<HeadFlags<'a>> {
}
ArgKind::ShortFlag('n') => match iterator.next() {
Some(ArgKind::Arg(arg)) => {
let num = arg.to_str().and_then(|a| a.parse::<u64>().ok());
if let Some(num) = num {
lines = Some(num);
if let Some(s) = arg.to_str() {
match parse_line_count(s) {
Ok(count) => lines = Some(count),
Err(_) => bail!("expected a numeric value following -n"),
}
} else {
bail!("expected a numeric value following -n")
}
Expand All @@ -150,7 +252,7 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result<HeadFlags<'a>> {
if flag == "lines" || flag == "lines=" {
bail!("expected a value for --lines");
} else if let Some(arg) = flag.strip_prefix("lines=") {
lines = Some(arg.parse::<u64>()?);
lines = Some(parse_line_count(arg)?);
} else {
arg.bail_unsupported()?
}
Expand All @@ -161,7 +263,7 @@ fn parse_args<'a>(args: &'a [OsString]) -> Result<HeadFlags<'a>> {

Ok(HeadFlags {
path: path.unwrap_or(OsStr::new("-")),
lines: lines.unwrap_or(10),
lines: lines.unwrap_or(LineCount::First(10)),
})
}

Expand Down Expand Up @@ -231,56 +333,78 @@ mod test {
parse_args(&[]).unwrap(),
HeadFlags {
path: OsStr::new("-"),
lines: 10
lines: LineCount::First(10)
}
);
assert_eq!(
parse_args(&["-n".into(), "5".into()]).unwrap(),
HeadFlags {
path: OsStr::new("-"),
lines: 5
lines: LineCount::First(5)
}
);
assert_eq!(
parse_args(&["--lines=5".into()]).unwrap(),
HeadFlags {
path: OsStr::new("-"),
lines: 5
lines: LineCount::First(5)
}
);
assert_eq!(
parse_args(&["path".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: 10
lines: LineCount::First(10)
}
);
assert_eq!(
parse_args(&["-n".into(), "5".into(), "path".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: 5
lines: LineCount::First(5)
}
);
assert_eq!(
parse_args(&["--lines=5".into(), "path".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: 5
lines: LineCount::First(5)
}
);
assert_eq!(
parse_args(&["path".into(), "-n".into(), "5".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: 5
lines: LineCount::First(5)
}
);
assert_eq!(
parse_args(&["path".into(), "--lines=5".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: 5
lines: LineCount::First(5)
}
);
// negative line counts (all but last N)
assert_eq!(
parse_args(&["-n".into(), "-1".into()]).unwrap(),
HeadFlags {
path: OsStr::new("-"),
lines: LineCount::AllButLast(1)
}
);
assert_eq!(
parse_args(&["-n".into(), "-5".into(), "path".into()]).unwrap(),
HeadFlags {
path: OsStr::new("path"),
lines: LineCount::AllButLast(5)
}
);
assert_eq!(
parse_args(&["--lines=-3".into()]).unwrap(),
HeadFlags {
path: OsStr::new("-"),
lines: LineCount::AllButLast(3)
}
);
assert_eq!(
Expand All @@ -304,4 +428,83 @@ mod test {
"unsupported flag: -t"
);
}

#[tokio::test]
async fn copies_all_but_last_lines() {
let (reader, mut writer) = pipe();
let reader_handle = reader.pipe_to_string_handle();
let data = b"line1\nline2\nline3\nline4\nline5\n";
let mut offset = 0;
let result = copy_all_but_last_lines(
&mut writer,
1,
&KillSignal::default(),
|buffer| {
if offset >= data.len() {
return Ok(0);
}
let read_length = min(buffer.len(), data.len() - offset);
buffer[..read_length]
.copy_from_slice(&data[offset..(offset + read_length)]);
offset += read_length;
Ok(read_length)
},
);
drop(writer);
assert_eq!(reader_handle.await.unwrap(), "line1\nline2\nline3\nline4\n");
assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0);
}

#[tokio::test]
async fn copies_all_but_last_two_lines() {
let (reader, mut writer) = pipe();
let reader_handle = reader.pipe_to_string_handle();
let data = b"line1\nline2\nline3\nline4\nline5\n";
let mut offset = 0;
let result = copy_all_but_last_lines(
&mut writer,
2,
&KillSignal::default(),
|buffer| {
if offset >= data.len() {
return Ok(0);
}
let read_length = min(buffer.len(), data.len() - offset);
buffer[..read_length]
.copy_from_slice(&data[offset..(offset + read_length)]);
offset += read_length;
Ok(read_length)
},
);
drop(writer);
assert_eq!(reader_handle.await.unwrap(), "line1\nline2\nline3\n");
assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0);
}

#[tokio::test]
async fn copies_all_but_last_lines_when_skip_exceeds_total() {
let (reader, mut writer) = pipe();
let reader_handle = reader.pipe_to_string_handle();
let data = b"line1\nline2\n";
let mut offset = 0;
let result = copy_all_but_last_lines(
&mut writer,
5,
&KillSignal::default(),
|buffer| {
if offset >= data.len() {
return Ok(0);
}
let read_length = min(buffer.len(), data.len() - offset);
buffer[..read_length]
.copy_from_slice(&data[offset..(offset + read_length)]);
offset += read_length;
Ok(read_length)
},
);
drop(writer);
// when skip_last >= total_lines, output should be empty
assert_eq!(reader_handle.await.unwrap(), "");
assert_eq!(result.unwrap().into_exit_code_and_handles().0, 0);
}
}
Loading