Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog
NOTE: isort follows the [semver](https://semver.org/) versioning standard.
Find out more about isort's release policy [here](https://pycqa.github.io/isort/docs/major_releases/release_policy).

### 5.13.3 (Unreleased)

- Fixed #2393: Calling isort with --sort-reexports with input from stdin fails due to non-seekable streams @jasur-py

### 5.13.2 December 13 2023

- Apply the bracket fix from issue #471 only for use_parentheses=True (#2184) @bp72
Expand Down
159 changes: 123 additions & 36 deletions isort/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import textwrap
from io import StringIO
import textwrap
from itertools import chain
from typing import List, TextIO, Union
import sys

import isort.literal
from isort.settings import DEFAULT_CONFIG, Config
Expand Down Expand Up @@ -52,6 +53,38 @@ def process(
Returns `True` if there were changes that needed to be made (errors present) from what
was provided in the input_stream, otherwise `False`.
"""
# Check if output stream is seekable for reexport handling
output_seekable = False
# Explicitly treat sys.stdout and sys.stderr as non-seekable
if output_stream in (sys.stdout, sys.stderr):
output_seekable = False
elif hasattr(output_stream, 'seekable'):
try:
output_seekable = output_stream.seekable()
if output_seekable:
# Try a test seek to see if it actually works
pos = output_stream.tell()
try:
output_stream.seek(pos)
except Exception:
output_seekable = False
except Exception:
output_seekable = False
elif all(hasattr(output_stream, attr) for attr in ('seek', 'tell', 'truncate')):
try:
pos = output_stream.tell()
output_stream.seek(pos)
output_seekable = True
except Exception:
output_seekable = False
# Use internal buffer if output stream is not seekable and we might need reexport sorting
internal_output = None
if not output_seekable and config.sort_reexports:
internal_output = StringIO()
_output_stream = internal_output
else:
_output_stream = output_stream

line_separator: str = config.line_ending
add_imports: List[str] = [format_natural(addition) for addition in config.add_imports]
import_section: str = ""
Expand Down Expand Up @@ -134,28 +167,44 @@ def process(
if not line_separator:
line_separator = "\n"

if code_sorting and code_sorting_section:
if is_reexport:
output_stream.seek(output_stream.tell() - reexport_rollback)
if code_sorting and code_sorting_section and is_reexport:
if output_seekable:
_output_stream.seek(_output_stream.tell() - reexport_rollback)
reexport_rollback = 0
sorted_code = textwrap.indent(
isort.literal.assignment(
code_sorting_section,
str(code_sorting),
extension,
config=_indented_config(config, indent),
),
code_sorting_indent,
)
made_changes = made_changes or _has_changed(
before=code_sorting_section,
after=sorted_code,
line_separator=line_separator,
ignore_whitespace=config.ignore_whitespace,
)
output_stream.write(sorted_code)
if is_reexport:
output_stream.truncate()
else:
if not output_seekable and reexport_rollback > 0:
current_value = _output_stream.getvalue()
# Find the last occurrence of '__all__' and truncate to its index
idx = current_value.rfind('__all__')
if idx != -1:
# Truncate to the start of the line containing __all__
line_start = current_value.rfind('\n', 0, idx)
if line_start == -1:
line_start = 0
else:
line_start += 1
_output_stream.seek(0)
_output_stream.truncate(0)
_output_stream.write(current_value[:line_start])
reexport_rollback = 0
sorted_code = textwrap.indent(
isort.literal.assignment(
code_sorting_section,
str(code_sorting),
extension,
config=_indented_config(config, indent),
),
code_sorting_indent,
)
made_changes = made_changes or _has_changed(
before=code_sorting_section,
after=sorted_code,
line_separator=line_separator,
ignore_whitespace=config.ignore_whitespace,
)
_output_stream.write(sorted_code)
if is_reexport:
_output_stream.truncate()
else:
stripped_line = line.strip()
if stripped_line and not line_separator:
Expand Down Expand Up @@ -239,6 +288,21 @@ def process(
code_sorting_indent = line[: -len(line.lstrip())]
not_imports = True
code_sorting_section += line
if is_reexport and not output_seekable and reexport_rollback > 0:
current_value = _output_stream.getvalue()
# Find the last occurrence of '__all__' and truncate to its index
idx = current_value.rfind('__all__')
if idx != -1:
# Truncate to the start of the line containing __all__
line_start = current_value.rfind('\n', 0, idx)
if line_start == -1:
line_start = 0
else:
line_start += 1
_output_stream.seek(0)
_output_stream.truncate(0)
_output_stream.write(current_value[:line_start])
reexport_rollback = 0
reexport_rollback = len(line)
is_reexport = True
elif code_sorting:
Expand All @@ -259,11 +323,29 @@ def process(
ignore_whitespace=config.ignore_whitespace,
)
if is_reexport:
output_stream.seek(output_stream.tell() - reexport_rollback)
reexport_rollback = 0
output_stream.write(sorted_code)
if output_seekable:
_output_stream.seek(_output_stream.tell() - reexport_rollback)
reexport_rollback = 0
else:
if not output_seekable and reexport_rollback > 0:
current_value = _output_stream.getvalue()
# Find the last occurrence of '__all__'
# and truncate to its index
idx = current_value.rfind('__all__')
if idx != -1:
# Truncate to the start of the line containing __all__
line_start = current_value.rfind('\n', 0, idx)
if line_start == -1:
line_start = 0
else:
line_start += 1
_output_stream.seek(0)
_output_stream.truncate(0)
_output_stream.write(current_value[:line_start])
reexport_rollback = 0
_output_stream.write(sorted_code)
if is_reexport:
output_stream.truncate()
_output_stream.truncate()
not_imports = True
code_sorting = False
code_sorting_section = ""
Expand All @@ -277,7 +359,7 @@ def process(
or stripped_line in config.section_comments_end
):
if import_section and not contains_imports:
output_stream.write(import_section)
_output_stream.write(import_section)
import_section = line
not_imports = False
else:
Expand Down Expand Up @@ -367,7 +449,7 @@ def process(
lines_before += line
continue
if not import_section:
output_stream.write("".join(lines_before))
_output_stream.write("".join(lines_before))
lines_before = []

raw_import_section: str = import_section
Expand All @@ -384,7 +466,7 @@ def process(
add_line_separator = line_separator or "\n"
import_section = add_line_separator.join(add_imports) + add_line_separator
if end_of_file and index != 0:
output_stream.write(add_line_separator)
_output_stream.write(add_line_separator)
contains_imports = True
add_imports = []

Expand All @@ -404,7 +486,7 @@ def process(
import_section += line
raw_import_section += line
if not contains_imports:
output_stream.write(import_section)
_output_stream.write(import_section)

else:
leading_whitespace = import_section[: -len(import_section.lstrip())]
Expand Down Expand Up @@ -444,12 +526,12 @@ def process(
line_separator=line_separator,
ignore_whitespace=config.ignore_whitespace,
)
output_stream.write(sorted_import_section)
_output_stream.write(sorted_import_section)
if not line and not indent and next_import_section:
output_stream.write(line_separator)
_output_stream.write(line_separator)

if indent:
output_stream.write(line)
_output_stream.write(line)
if not next_import_section:
indent = ""

Expand All @@ -461,7 +543,7 @@ def process(
import_section = next_import_section
next_import_section = ""
else:
output_stream.write(line)
_output_stream.write(line)
not_imports = False

if stripped_line and not in_quote and not import_section and not next_import_section:
Expand All @@ -471,7 +553,7 @@ def process(
if not new_line:
break

output_stream.write(new_line)
_output_stream.write(new_line)
stripped_line = new_line.strip().split("#")[0]

if stripped_line.startswith(("raise", "yield")):
Expand All @@ -480,13 +562,18 @@ def process(
if not new_line:
break

output_stream.write(new_line)
_output_stream.write(new_line)
stripped_line = new_line.strip().split("#")[0]

if made_changes and config.only_modified:
for output_str in verbose_output:
print(output_str)

# Write internal buffer to actual output stream if we used one
if internal_output is not None:
internal_output.seek(0)
output_stream.write(internal_output.read())

return made_changes


Expand Down
83 changes: 83 additions & 0 deletions tests/unit/test_isort.py
Original file line number Diff line number Diff line change
Expand Up @@ -5741,3 +5741,86 @@ def test_reexport_multiline_long_rollback() -> None:
test
"""
assert isort.code(test_input, config=Config(sort_reexports=True)) == expd_output


def test_reexport_non_seekable_stream() -> None:
"""Test that reexport sorting works with non-seekable streams like stdout"""
from io import StringIO

test_input = """from test import B, A
__all__ = ["B", "A"]"""

expected_output = """from test import A, B

__all__ = ['A', 'B']"""

# Test with a non-seekable stream (simulating stdout)
input_stream = StringIO(test_input)
output_stream = StringIO()

# Mock sys.stdout to be non-seekable
original_stdout = sys.stdout
try:
sys.stdout = output_stream
api.sort_stream(
input_stream=input_stream,
output_stream=output_stream,
config=Config(sort_reexports=True),
)
output_stream.seek(0)
result = output_stream.read()
assert result == expected_output
finally:
sys.stdout = original_stdout

def test_reexport_non_seekable_stream() -> None:
"""Test that reexport sorting works with non-seekable streams like stdout"""
from io import StringIO

test_input = """from test import B, A
__all__ = ["B", "A"]"""

expected_output = """from test import A, B

__all__ = ['A', 'B']"""

# Test with a non-seekable stream (simulating stdout)
input_stream = StringIO(test_input)

# Create a non-seekable output stream that allows reading the result
class NonSeekableStream(StringIO):
def __init__(self):
super().__init__()
self._allow_seek = False

def seek(self, *args, **kwargs):
if not self._allow_seek:
raise OSError("Stream is not seekable")
return super().seek(*args, **kwargs)

def tell(self, *args, **kwargs):
if not self._allow_seek:
raise OSError("Stream is not seekable")
return super().tell(*args, **kwargs)

def truncate(self, *args, **kwargs):
if not self._allow_seek:
raise OSError("Stream is not seekable")
return super().truncate(*args, **kwargs)

def allow_seek(self):
self._allow_seek = True

output_stream = NonSeekableStream()

api.sort_stream(
input_stream=input_stream,
output_stream=output_stream,
config=Config(sort_reexports=True),
)

# Allow seeking to read the result
output_stream.allow_seek()
output_stream.seek(0)
result = output_stream.read()
assert result == expected_output