-
Notifications
You must be signed in to change notification settings - Fork 7
improve pdf file created by FOM #327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @caglayantuna !
I made some suggestion to improve the code readability.
Also, it would be nice to add some unit tests for these new functions.
/ f"{parameter_name}.csv" | ||
) | ||
if os.path.exists(file_path): | ||
file_to_plot = file_path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines 191 to 211 can be replaced by:
files_to_plot = [
f for f in self.path_save_model_parameters_convergence.glob(f"{parameter_name}*.csv")
]
If needed the regexp can be made more robust but I think this should be enough.
file_to_plot = file_path | ||
|
||
if any( | ||
f"_{j}" in str(file) for j, file in enumerate(files_to_plot) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My previous suggestion gives you a list of files matching your query. This list can be of length zero, one or more. I suggest you handle these three cases rather than this strange condition.
# in this function you'd have:
for i, parameter_name in enumerate(params_to_plot[page : page + 6]):
files_to_plot = [
f for f in self.path_save_model_parameters_convergence.glob(f"{parameter_name}*.csv")
]
_plot_parameter(files_to_plot, ax[i])
# with helper functions:
def _plot_parameter(files_to_plot: Sequence[Path], ax: Axes):
if len(files_to_plot) == 0:
# Do something...
elif len(files_to_plot) == 1:
_plot_single_parameter_file(files_to_plot[0], ax)
else:
_plot_multiple_parameter_file(files_to_plot, ax)
def _plot_single_parameter_file(file_to_plot: Path, ax: Axes) -> None:
df = pd.read_csv(file_to_plot, index_col=0, header=None)
ax.plot(df)
def _plot_multiple_parameter_file(files_to_plot: Iterable[Path], ax: Axes) -> None:
colormap = cm.viridis(np.linspace(0, 1, len(files_to_plot)))
for idx, file in enumerate(files_to_plot):
df = pd.read_csv(file, index_col=0, header=None)
line = ax.plot(df, color=colormap[idx])[0]
lines.append(line)
ax.legend(lines, [f.stem for f in files_to_plot], loc="best")
Does that make sense ?
Thanks @NicolasGensollen for your comments. I changed this PR after some discussions with the team. Review is not necessary for now. I am attaching latest pdf results to discuss about them. |
78e44b5
to
65beca5
Compare
This PR is ready for reviewing. I added a test function for Also, I added |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @caglayantuna
I made a first pass. My general comment is that the function you added is very very long, and hence difficult to read and understand as a whole.
I think the code would read much better by extracting some logic in smaller functions.
|
||
params_to_plot = list(model.state.tracked_variables - to_skip) | ||
|
||
files = os.listdir(self.path_save_model_parameters_convergence) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.path_save_model_parameters_convergence
should be a Path object and not a string since it is a path.
Also, the logic used to get the files is duplicated, I'd turn it into a function:
def _get_files_related_to_parameters(self, parameters: Iterable[str]) - > list[Path]:
return [
f for f in self.path_save_model_parameters_convergence.iterdir()
if any(f.name.startswith(param) for param in parameters)
]
So getting the files to plot here would become:
files_to_plot = self._get_files_related_to_parameters(params_to_plot).sort()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Nicolas I added several helper functions to simplify this function
# If plot sourcewise is true, new sourcewise csv files will be created | ||
if self.plot_sourcewise: | ||
new_files = [] | ||
for param_name in params_with_sources: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would become:
related_files = self._get_files_related_to_parameters(params_with_sources)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
for source_idx in range(num_sources): | ||
combined_data = [] | ||
|
||
for file_name in related_files: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would also be simplified since the list holds the full path:
for file_path in related_files:
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I did
combined_data.append(df.iloc[:, source_idx]) | ||
|
||
combined_df = pd.concat(combined_data, axis=1, join="inner") | ||
new_file_name = "sourcewise_" f"{param_name}_{source_idx + 1}.csv" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks weird, is this what you wanted to write:
new_file_name = "sourcewise_" f"{param_name}_{source_idx + 1}.csv" | |
new_file_name = f"sourcewise_{param_name}_{source_idx + 1}.csv" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, since I added sourcewise
later, I forgot to change "f". now , it is ok
_, ax = plt.subplots(3, 2, figsize=(width, 3 * height_per_row)) | ||
ax = ax.flatten() | ||
|
||
for i, file_name in enumerate(files_to_plot[page : page + 6]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for i, file_name in enumerate(files_to_plot[page : page + 6]): | |
for i, file_path in enumerate(files_to_plot[page : page + 6]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done
ax = ax.flatten() | ||
|
||
for i, file_name in enumerate(files_to_plot[page : page + 6]): | ||
file_path = self.path_save_model_parameters_convergence / file_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
file_path = self.path_save_model_parameters_convergence / file_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed it
|
||
for i, file_name in enumerate(files_to_plot[page : page + 6]): | ||
file_path = self.path_save_model_parameters_convergence / file_name | ||
parameter_name = file_name.split(".csv")[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
parameter_name = file_name.split(".csv")[0] | |
parameter_name = file_path.name.split(".csv")[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
parameter_name = parameter_name[:-3] | ||
else: | ||
feature_index = int(parameter_name[-1]) | ||
parameter_name = parameter_name[:-2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very fragile and not generalizable. You need to use regular expressions for these kind of tasks.
Here is a proposition which could still be improved:
def _extract_parameter_name_and_index(parameter_name: str) -> tuple[Optional[str], Optional[int]]:
match = re.search(r'^(.*?)(\d+)$', parameter_name)
if match:
return match.group(1).strip("_"), int(match.group(2))
return None, None
Which gives:
>>> _extract_parameter_name_and_index("mixing_matrix_10")
('mixing_matrix', 10)
>>> _extract_parameter_name_and_index("mixing_matrix_1")
('mixing_matrix', 1)
>>> _extract_parameter_name_and_index("mixing_matrix_10400")
('mixing_matrix', 10400)
>>> _extract_parameter_name_and_index("mixing_matrix_10400_foo")
(None, None)
>>> _extract_parameter_name_and_index("mixing_ma")
(None, None)
>>> _extract_parameter_name_and_index("100_mixing_ma")
(None, None)
>>> _extract_parameter_name_and_index("v0_12")
('v0', 12)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much for your suggestions, it is very useful. However, I couldn't find a solution when input is v0
. This is a parameter name and 0
is not index. To fix this, I had to add specific condition for this parameter. If you have any suggestion, would be nice.
df_convergence = pd.read_csv(file_path, index_col=0, header=None) | ||
ax[i].plot(df_convergence) | ||
|
||
if parameter_name == "mixing_matrix": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very hardcoded, would there be a way to have something more generic ?
Also, could you move that to a separate function, this function is enormous and hard to follow along. Something like this maybe:
def _set_title_for_parameter(ax, parameter_name: str):
if parameter_name == "mixing_matrix":
ax.set_title(...)
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normally, we could use params_with_sources
instead of hardcoding. However, zeta
in joint model has different structure. I couldn't find a better solution but I created a function as you suggested.
65beca5
to
ea3bae0
Compare
Thanks @NicolasGensollen for your suggestions. They are really helpful. I changed the PR accordingly, but I still have some questions that you will see in the comments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @caglayantuna !
I made another pass with more suggestions. Most of them go in the same direction of breaking down huge functions into smaller units which are easier to understand. The base idea of doing so, is that the multi-line operations will be abstracted into the function name in your brain. If the function has no nasty side effects, does only one thing, and has a well-chosen name, it helps a lot and reduces the amount of objects your brain needs to keep track of.
Btw, feel free to improve names if you have better ideas !
ax[i].legend(events, loc="best") | ||
return ax | ||
|
||
def _update_files_sourcewise(self, files_to_plot, params_with_sources, num_sources): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type hints
Update the list of files by creating sourcewise files based on the parameters with sources. | ||
|
||
This function processes parameters that have sources and generates new CSV files for each source. | ||
These new files are added to the list of files to plot. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two small comments here:
First of all, this function is doing multiple things, and it's clearly visible when reading the description. Because of that, it's long, and not easy to understand.
Second, it is mutating the files_to_plot
list for no reason. A safer implementation would compute the new_files
list and return that. The caller would then perform the aggregation of the files_to_plot
and new_files
. Doing that, this function doesn't even need the files_to_plot
list. It could basically look like that:
def _compute_files_sourcewise(self, params_with_sources: list[str], num_sources: int) -> list[Path]:
new_files = []
for param_name in params_with_sources:
...
return new_files
The caller:
files_to_plot = [
file
for file in files_to_plot
if not any(file.name.startswith(param) for param in params_with_sources)
]
files_to_plot.extend(self._compute_files_sourcewise(...))
WDYT ?
for source_idx in range(num_sources): | ||
combined_data = [] | ||
|
||
for file_path in related_files: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have three imbricated for loop in this function. It's another sign that we are trying to do too much at once.
For each source, we read the related files one by one and concatenate everything together, this concatenation is then saved to disk and the path is added to the new_files
list.
Here is a suggestion of how we could refactor this to help readability:
...
for source_idx in range(num_sources):
new_files.append(
_concatenate_and_save_sourcewise_parameters(
param_name, related_files, source_idx
)
)
...
def _concatenate_and_save_sourcewise_parameters(
self, parameter_name: str, files: Iterable[Path], source_idx: int
) -> Path:
"""Explain what I do..."""
df = _concatenate_parameters(files, source_idx)
file_name = f"sourcewise_{parameter_name}_{source_idx + 1}.csv"
file_path = self.path_save_model_parameters_convergence / file_name
df.to_csv(file_path, header=False)
return file_path
def _concatenate_parameters(files: Iterable[Path], source_idx: int) -> pd.DataFrame:
"""Explain me..."""
return pd.concat(
[
pd.read_csv(file_path, index_col=0, header=None).iloc[:, source_idx]
for file_path in files
],
axis=1,
join="inner",
)
improve doc and function improve FOM improve fom with tests add local and sourwise option
3d89817
to
ce0bda7
Compare
Co-authored-by: Gensollen <[email protected]>
Hello @NicolasGensollen , thanks for your suggestions again. I changed the PR accordingly. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @caglayantuna !
This PR solves the issue #123 .
To do this, I improved the
save_plot_convergence_model_parameters
function. Now, it creates a pdf with several pages (6 plots per page) and it plots some additional parameters such as mixing matrix for all models and zeta from joint model. Since these new parameters have multiple files, I have also added legends for them specifically.I add an example pdf which was created after fitting data with joint model.
convergence_parameters.pdf