Update chatbot/train.py and chatbot/app.py to improve model performance #200
baranwalayush wants to merge 1 commit into AOSSIE-Org:main from …
Conversation
Walkthrough

The chatbot module is enhanced with NLTK-based NLP standardization (word_tokenize and PorterStemmer) integrated into both training and inference pipelines. Training configuration is externalized into constants, epoch-level loss monitoring is added, and the dataset now returns PyTorch tensors instead of raw numpy arrays for consistency.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks: ❌ Failed checks (2 warnings) | ✅ Passed checks (3 passed)
Actionable comments posted: 2
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
chatbot/train.py (2)
84-91: Type mismatch: `bag_of_words` returns tensors, but code expects numpy arrays.

`bag_of_words` returns a `torch.tensor` (line 32), but these are collected into a list and converted to `np.array` (line 90). This creates a numpy object array of tensors. Later, `__getitem__` (line 111) calls `torch.from_numpy()` on these tensor objects, which will fail because `from_numpy` expects a numpy ndarray, not a PyTorch tensor. Either convert `bag_of_words` output to numpy before appending, or handle tensors directly.

Option 1: Convert to numpy in `bag_of_words` (consistent with original design):

```diff
 def bag_of_words(tokenized_sentence, words):
     sentence_words = [stem(word) for word in tokenized_sentence]
     bag = [1.0 if word in sentence_words else 0.0 for word in words]
-    return torch.tensor(bag, dtype=torch.float32)
+    return np.array(bag, dtype=np.float32)
```

Option 2: Stack tensors directly (avoid numpy conversion):

```diff
-X_train = np.array(X_train)
+X_train = torch.stack(X_train)
 y_train = np.array(y_train)
```

And update `__getitem__`:

```diff
 def __getitem__(self, index):
-    return torch.from_numpy(self.x_data[index]), torch.tensor(self.y_data[index])
+    return self.x_data[index], torch.tensor(self.y_data[index])
```
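As a quick illustration of the failure mode and of both fixes, a minimal standalone sketch (not code from the PR):

```python
import numpy as np
import torch

t = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)

# torch.from_numpy accepts only numpy ndarrays; a tensor raises TypeError.
try:
    torch.from_numpy(t)
except TypeError as err:
    print(err)  # e.g. "expected np.ndarray (got Tensor)"

# Option 1 keeps features as numpy until __getitem__ converts them:
row = np.array([1.0, 0.0, 1.0], dtype=np.float32)
print(torch.from_numpy(row))      # tensor([1., 0., 1.])

# Option 2 keeps features as tensors throughout:
print(torch.stack([t, t]).shape)  # torch.Size([2, 3])
```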
117-121: `num_workers=0` does not meet the linked issue requirement.

Issue #198 specifically requested "set DataLoader `num_workers > 0` to enable parallel data fetching" for performance improvement. The current value is `0`, which disables parallel loading.

```diff
 train_loader = DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=True,
-                          num_workers=0)
+                          num_workers=2)
```

Note: If Windows compatibility is a concern, consider using a conditional:

```python
import os
num_workers = 0 if os.name == 'nt' else 2
```
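For context, a minimal sketch of how the conditional slots into the loader setup; the `TensorDataset` here is a stand-in for the PR's own dataset class:

```python
import os
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; the PR uses its own chat dataset class.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))

# Windows spawns worker processes instead of forking, which is where
# num_workers > 0 most often breaks, hence the os.name check.
num_workers = 0 if os.name == 'nt' else 2

train_loader = DataLoader(dataset=dataset,
                          batch_size=8,
                          shuffle=True,
                          num_workers=num_workers)
```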
🧹 Nitpick comments (2)
chatbot/train.py (2)
8-11: Consider guarding NLTK resource download.

Calling `nltk.download('punkt')` at module import time can cause unexpected network calls in production. Consider wrapping this in a try/except that checks if the resource already exists, or move it to a setup/initialization function.

```diff
 import nltk
 from nltk.stem.porter import PorterStemmer
-nltk.download('punkt')
+try:
+    nltk.data.find('tokenizers/punkt')
+except LookupError:
+    nltk.download('punkt')
```
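(`nltk.data.find` raises `LookupError` when the resource is missing, so the guarded version only touches the network on the first run. Note that newer NLTK releases may additionally require the `punkt_tab` resource for `word_tokenize`.)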
20-32: Code duplication with `app.py`.

The `tokenize`, `stem`, and `bag_of_words` functions are duplicated in `app.py` (lines 14-24). Consider extracting these into a shared module (e.g., `nlp_utils.py`) to ensure consistency and reduce maintenance burden.
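A minimal sketch of what such a shared module could look like; the `nlp_utils.py` name and exact bodies are illustrative, based only on the function signatures referenced in this review:

```python
# nlp_utils.py -- hypothetical shared module; both train.py and app.py
# would then `from nlp_utils import tokenize, stem, bag_of_words`.
import numpy as np
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer

stemmer = PorterStemmer()

def tokenize(sentence):
    # Split a sentence into word/punctuation tokens.
    return word_tokenize(sentence)

def stem(word):
    # Lowercase and reduce a word to its Porter stem.
    return stemmer.stem(word.lower())

def bag_of_words(tokenized_sentence, words):
    # 1.0 for each vocabulary word present in the (stemmed) sentence.
    sentence_words = [stem(word) for word in tokenized_sentence]
    return np.array([1.0 if w in sentence_words else 0.0 for w in words],
                    dtype=np.float32)
```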
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `chatbot/app.py` (3 hunks)
- `chatbot/train.py` (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
chatbot/app.py (1)
chatbot/train.py (3)
`stem` (24-25), `tokenize` (21-22), `bag_of_words` (27-32)
chatbot/train.py (1)
chatbot/app.py (3)
`stem` (18-19), `tokenize` (15-16), `bag_of_words` (21-24)
🪛 Ruff (0.14.8)
chatbot/app.py
96-96: Possible binding to all interfaces
(S104)
96-96: Use of debug=True in Flask app detected
(S201)
🔇 Additional comments (6)
chatbot/train.py (3)
13-18: Good use of configuration constants.

Extracting `INTENTS_FILE` and `MODEL_SAVE_FILE` into constants improves maintainability. The global stemmer instance aligns with the approach in `app.py`.
132-159: Good implementation of epoch-averaged loss tracking.

The weighted loss accumulation (`loss.item() * words.size(0)`) and division by `len(dataset)` correctly compute the average loss per sample across the epoch, fulfilling the linked issue requirement.
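For reference, the pattern being praised looks roughly like this: a runnable sketch with toy stand-ins, since the PR's own dataset, model, and hyperparameters aren't shown here:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins so the pattern runs end to end.
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 3, (100,)))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_epochs = 3

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for words, labels in train_loader:
        outputs = model(words)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so the final average is
        # per-sample even when the last batch is smaller.
        epoch_loss += loss.item() * words.size(0)

    avg_loss = epoch_loss / len(dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, avg loss: {avg_loss:.4f}")
```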
161-172: LGTM!

Model saving correctly uses the `MODEL_SAVE_FILE` constant and includes all necessary metadata for inference.

chatbot/app.py (3)
8-24: LGTM - NLP preprocessing aligns with training.

The tokenization and stemming implementation correctly mirrors `train.py`, ensuring consistent preprocessing between training and inference.
47-48: Verify `torch.load` behavior with current PyTorch version.

In PyTorch 2.6+, `torch.load` defaults to `weights_only=True`, which would fail here since the saved data includes non-tensor objects (`all_words`, `tags`). Consider explicitly setting the parameter for forward compatibility.

```diff
 FILE = "data.pth"
-data = torch.load(FILE)
+data = torch.load(FILE, weights_only=False)
```

Alternatively, for better security, use `torch.load(FILE, weights_only=True, map_location=device)` and save metadata separately in a JSON file.
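A sketch of that safer split; the file names and stand-in values are illustrative, based only on the fields this review mentions:

```python
import json
import torch
import torch.nn as nn

# Stand-ins for the PR's model and vocabulary (illustrative only).
model = nn.Linear(4, 2)
all_words, tags = ["hi", "bye"], ["greeting", "goodbye"]
device = torch.device("cpu")

# Save: tensors go through torch.save; plain-Python metadata goes to JSON.
torch.save(model.state_dict(), "model.pth")
with open("metadata.json", "w") as f:
    json.dump({"all_words": all_words, "tags": tags}, f)

# Load: weights_only=True is safe here because only tensors are unpickled.
state_dict = torch.load("model.pth", weights_only=True, map_location=device)
with open("metadata.json") as f:
    meta = json.load(f)
```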
63-93: Chat endpoint logic is correct.

The inference pipeline properly uses the updated NLP functions and handles uncertain predictions gracefully. One minor consideration: returning raw exception messages (`str(e)`) in line 93 could expose internal details. In production, consider logging the full error server-side and returning a generic message to clients.
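A sketch of that pattern; the route name and the `get_response` helper are hypothetical stand-ins for the PR's endpoint:

```python
import logging
from flask import Flask, jsonify, request

app = Flask(__name__)
logger = logging.getLogger(__name__)

def get_response(message):
    # Hypothetical stand-in for the PR's model inference.
    return "..."

@app.route("/chat", methods=["POST"])
def chat():
    try:
        message = request.get_json()["message"]
        return jsonify({"response": get_response(message)})
    except Exception:
        # Full details stay in the server log; the client sees a generic message.
        logger.exception("Error handling /chat request")
        return jsonify({"error": "Internal server error"}), 500
```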
```diff
 if __name__ == '__main__':
-    app.run(host="0.0.0.0", port=5000,debug=True)
+    app.run(host="0.0.0.0", port=5000, debug=True)
```
Development server configuration flagged by static analysis.

`debug=True` and binding to `0.0.0.0` are appropriate for development but pose security risks in production:

- `debug=True` enables the interactive debugger, which can execute arbitrary code
- `0.0.0.0` exposes the service to all network interfaces
Consider using environment variables or a configuration flag:

```diff
+import os
+
 if __name__ == '__main__':
-    app.run(host="0.0.0.0", port=5000, debug=True)
+    debug_mode = os.environ.get('FLASK_DEBUG', 'false').lower() == 'true'
+    app.run(host="0.0.0.0", port=5000, debug=debug_mode)
```

For production, use a WSGI server (e.g., Gunicorn) instead of the Flask development server.
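For example, assuming the Flask instance is named `app` inside `chatbot/app.py`, something like `gunicorn -w 4 -b 127.0.0.1:5000 app:app` (run from the `chatbot/` directory) would serve it without the interactive debugger.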
🧰 Tools
🪛 ast-grep (0.40.0)
[warning] 95-95: Running flask app with host 0.0.0.0 could expose the server publicly.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-668]: Exposure of Resource to Wrong Sphere [OWASP A01:2021]: Broken Access Control [REFERENCES]
https://owasp.org/Top10/A01_2021-Broken_Access_Control
(avoid_app_run_with_bad_host-python)
[warning] 95-95: Detected Flask app with debug=True. Do not deploy to production with this flag enabled as it will leak sensitive information. Instead, consider using Flask configuration variables or setting 'debug' using system environment variables.
Context: app.run(host="0.0.0.0", port=5000, debug=True)
Note: [CWE-489] Active Debug Code. [REFERENCES]
- https://labs.detectify.com/2015/10/02/how-patreon-got-hacked-publicly-exposed-werkzeug-debugger/
(debug-enabled-python)
🪛 Ruff (0.14.8)
96-96: Possible binding to all interfaces
(S104)
96-96: Use of debug=True in Flask app detected
(S201)
🤖 Prompt for AI Agents
In chatbot/app.py around lines 95-96, the app is started with debug=True and
host="0.0.0.0", which is unsafe for production; change the startup to read
configuration from environment variables (e.g., FLASK_ENV or APP_DEBUG and
APP_HOST/APP_PORT), default debug to False, and only enable debug and binding to
0.0.0.0 when an explicit development flag is set; for production remove the
built-in server invocation and run under a WSGI server like Gunicorn (document
example command) so the code no longer enables the interactive debugger or
indiscriminately binds all interfaces by default.
```diff
@@ -1,23 +1,34 @@
 from operator import index
```
Remove unused import.

The `index` import from `operator` is never used in this file.

```diff
-from operator import index
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
In chatbot/train.py at line 1, the file imports `index` from the standard
library `operator` module but never uses it; remove the unused import line `from
operator import index` to clean up imports and avoid lint warnings, ensuring
there are no other references to `index` elsewhere in the file before committing
the change.
Description

In this PR, I have:

- updated `chatbot/app.py` to work with the improved `chatbot/train.py`

Let me know if there is anything else I need to take care of.
Fixes #198
Summary by CodeRabbit
New Features
Refactor