-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reorganize dependencies and enhance MAML visualization #41
Conversation
Reviewer's Guide by SourceryThis PR focuses on two main areas: reorganizing the project dependencies in requirements.txt and enhancing the codebase with new visualization capabilities and social network analysis features. The requirements.txt has been restructured with clear categorization and proper PyTorch dependency ordering. The code changes introduce new visualization functionality for MAML results and add a new Graph Attention Network (GAT) implementation for social network analysis. Sequence diagram for enhanced MAML visualizationsequenceDiagram
participant User
participant MAML as MAML Model
participant OS
participant Matplotlib
participant Logger
User->>MAML: Call visualize_adaptation()
MAML->>OS: Create 'maml_results' directory
OS-->>MAML: Directory created
MAML->>Matplotlib: Generate plots
Matplotlib-->>MAML: Return plot
MAML->>OS: Save plot to 'maml_results'
OS-->>MAML: Plot saved
MAML->>Logger: Log 'Saved visualization to ...'
Class diagram for the new GAT social network analysisclassDiagram
class AgentNode {
+String agent_id
+Dict attributes
+Dict state
}
class RelationshipGraph {
+String device
+int hidden_dim
+int num_heads
+DiGraph nx_graph
+Linear node_encoder
+Linear edge_encoder
+GATConv gat_layer
+add_agent(AgentNode agent)
+update_relationship(String source_id, String target_id, float trust, float influence, float familiarity)
+compute_social_dynamics() Tensor
+analyze_communities() List
+get_influence_paths(String source_id, String target_id) List
+visualize(String title)
}
AgentNode --> RelationshipGraph
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @leonvanbokhorst - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider adding version pins (e.g. numpy==1.21.0) to requirements.txt to ensure reproducible builds. This is especially important for ML dependencies that can have breaking changes between versions.
Here's what I looked at during the review
- 🟡 General issues: 1 issue found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
@@ -299,7 +306,7 @@ def setup_visualization_plot(x_data, y_data, plot_title, legend_labels): | |||
f"Adaptation Improvement: {((F.mse_loss(initial_pred, query_y) - F.mse_loss(adapted_pred, query_y)) / F.mse_loss(initial_pred, query_y) * 100):.1f}%") | |||
|
|||
plt.tight_layout() | |||
save_path = f'adaptation_plot_{task_name.replace(" ", "_")}.png' | |||
save_path = os.path.join(results_dir, f'adaptation_plot_{task_name.replace(" ", "_")}.png') | |||
plt.savefig(save_path, dpi=300, bbox_inches='tight') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Consider adding error handling for file writing operations
The current error handling only covers the visualization creation but not the file operations. Consider wrapping the file operations in a try-except block to handle potential permission or disk space issues.
plt.savefig(save_path, dpi=300, bbox_inches='tight') | |
try: | |
plt.savefig(save_path, dpi=300, bbox_inches='tight') | |
except (OSError, PermissionError) as e: | |
raise RuntimeError(f"Failed to save plot to {save_path}") from e |
features_tensor = torch.stack(features) | ||
return edge_index_tensor, self.edge_encoder(features_tensor) | ||
|
||
def visualize(self, title: str = "Relationship Graph") -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (complexity): Consider splitting the visualization method into smaller focused helper methods for nodes, edges and colorbars.
The visualize
method handles too many visualization concerns in one place. Consider splitting it into focused helper methods while keeping the main flow clear:
def visualize(self, title: str = "Relationship Graph") -> None:
"""Visualizes the relationship graph with edge weights and node attributes."""
fig, ax = plt.subplots(figsize=(12, 8))
pos = nx.spring_layout(self.nx_graph, k=1, iterations=50)
self._draw_nodes(ax, pos)
self._draw_edges(ax, pos)
self._add_colorbars(fig)
nx.draw_networkx_labels(self.nx_graph, pos, ax=ax)
ax.set_title(title)
ax.axis("off")
plt.show()
def _draw_nodes(self, ax, pos):
"""Handles node visualization logic."""
node_sizes = [1000 * data["attributes"]["extraversion"]
for _, data in self.nx_graph.nodes(data=True)]
node_colors = [data["attributes"]["openness"]
for _, data in self.nx_graph.nodes(data=True)]
nx.draw_networkx_nodes(
self.nx_graph, pos,
node_size=node_sizes,
node_color=node_colors,
cmap=plt.cm.viridis,
alpha=0.7,
ax=ax
)
def _draw_edges(self, ax, pos):
"""Handles edge visualization logic."""
edges, colors, widths = zip(*[
((u, v), data["influence"], data["trust"] * 2)
for u, v, data in self.nx_graph.edges(data=True)
])
nx.draw_networkx_edges(
self.nx_graph, pos,
edgelist=edges,
edge_color=colors,
edge_cmap=plt.cm.coolwarm,
width=widths,
edge_vmin=0,
edge_vmax=1,
arrows=True,
arrowsize=20,
ax=ax
)
This refactoring:
- Separates concerns into focused methods
- Makes the main visualization flow clear
- Keeps all functionality intact
- Makes the code more maintainable and testable
Summary by Sourcery
Update the requirements.txt to reorganize dependencies by category and add new dependencies for PyTorch Geometric. Enhance the MAML visualization process by creating a results directory and updating plot save paths. Introduce a new script for social network analysis using Graph Attention Networks. Clean up the repository by removing unused files and directories.
Enhancements:
Chores: