
Commit 951890f

[JTH] binwaves modifications, standarize bug fixed, and logger option to not write to the console

1 parent: f224f07

File tree: 17 files changed (+751 −76 lines)

bluemath_tk/core/logging.py

Lines changed: 17 additions & 10 deletions
@@ -1,12 +1,16 @@
+import logging
 import os
 from datetime import datetime
-import pytz
-import logging
 from typing import Union

+import pytz
+

 def get_file_logger(
-    name: str, logs_path: str = None, level: Union[int, str] = "INFO"
+    name: str,
+    logs_path: str = None,
+    level: Union[int, str] = "INFO",
+    console: bool = True,
 ) -> logging.Logger:
     """
     Creates and returns a logger that writes log messages to a file.
@@ -16,9 +20,11 @@ def get_file_logger(
     name : str
         The name of the logger.
     logs_path : str, optional
        The file path where the log messages will be written (default is None).
+        The file path where the log messages will be written. Default is None.
     level : Union[int, str], optional
-        The logging level (default is "INFO").
+        The logging level. Default is "INFO".
+    console : bool
+        Whether or not to also emit logs to the console / terminal. Default is True.

     Returns
     -------
@@ -60,18 +66,19 @@ def get_file_logger(
         os.makedirs(os.path.dirname(logs_path))
     file_handler = logging.FileHandler(logs_path)

-    # Also ouput logs in the console
-    console_handler = logging.StreamHandler()
-
     # Define a logging format
     formatter = logging.Formatter(
         "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
     file_handler.setFormatter(formatter)
-    console_handler.setFormatter(formatter)

     # Add the file handler to the logger
     logger.addHandler(file_handler)
-    logger.addHandler(console_handler)
+
+    # Also output logs in the console if requested
+    if console:
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(formatter)
+        logger.addHandler(console_handler)

     return logger
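
A quick usage sketch of the new flag; the logger name and file path below are illustrative, not from the commit:

    from bluemath_tk.core.logging import get_file_logger

    # Hypothetical name and path, for illustration only.
    logger = get_file_logger(
        name="binwaves_run",
        logs_path="logs/binwaves.log",
        level="DEBUG",
        console=False,  # keep the terminal quiet; write to the file only
    )
    logger.info("Goes to logs/binwaves.log, not to the console")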

bluemath_tk/core/models.py

Lines changed: 10 additions & 4 deletions
@@ -64,11 +64,12 @@ def logger(self) -> logging.Logger:
     def logger(self, value: logging.Logger) -> None:
         self._logger = value

-    def set_logger_name(self, name: str, level: str = "INFO") -> None:
+    def set_logger_name(
+        self, name: str, level: str = "INFO", console: bool = True
+    ) -> None:
         """Sets the name of the logger."""

-        self.logger = get_file_logger(name=name)
-        self.logger.setLevel(level)
+        self.logger = get_file_logger(name=name, level=level, console=console)

     def save_model(self, model_path: str, exclude_attributes: List[str] = None) -> None:
         """Saves the model to a file."""
@@ -243,6 +244,7 @@ def standarize(
         self,
         data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
         scaler: StandardScaler = None,
+        transform: bool = False,
     ) -> Tuple[Union[np.ndarray, pd.DataFrame, xr.Dataset], StandardScaler]:
         """
         Standarize data using StandardScaler.
@@ -254,6 +256,8 @@ def standarize(
             Input data to be standarized.
         scaler : StandardScaler, optional
             Scaler object to use for standarization. Default is None.
+        transform : bool
+            Whether to just transform the data with an already fitted scaler. Default is False.

         Returns
         -------
@@ -263,7 +267,9 @@ def standarize(
             Scaler object used for standarization.
         """

-        standarized_data, scaler = standarize(data=data, scaler=scaler)
+        standarized_data, scaler = standarize(
+            data=data, scaler=scaler, transform=transform
+        )
         return standarized_data, scaler

     def destandarize(
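
A sketch of how the updated methods chain together. The import path, the assumption that these methods live on BlueMathModel, and the stub subclass are all illustrative (if the base class declares abstract methods, the stub would need to implement them):

    import numpy as np

    from bluemath_tk.core.models import BlueMathModel  # assumed location of the class


    class DemoModel(BlueMathModel):  # minimal stub, illustration only
        pass


    model = DemoModel()
    # level and console are now forwarded to get_file_logger in a single call
    model.set_logger_name(name="demo_model", level="DEBUG", console=False)

    train = np.random.default_rng(0).normal(size=(100, 3))
    new = np.random.default_rng(1).normal(size=(10, 3))
    train_std, scaler = model.standarize(data=train)  # fits the scaler
    new_std, _ = model.standarize(data=new, scaler=scaler, transform=True)  # reuses it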

bluemath_tk/core/operations.py

Lines changed: 15 additions & 4 deletions
@@ -205,6 +205,7 @@ def denormalize(
 def standarize(
     data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
     scaler: StandardScaler = None,
+    transform: bool = False,
 ) -> Tuple[Union[np.ndarray, pd.DataFrame, xr.Dataset], StandardScaler]:
     """
     Standarize data to have mean 0 and std 1.
@@ -215,6 +216,8 @@ def standarize(
         Input data to be standarized.
     scaler : StandardScaler, optional
         Scaler object to use for standarization. Default is None.
+    transform : bool
+        Whether to just transform the data with an already fitted scaler. Default is False.

     Returns
     -------
@@ -233,13 +236,21 @@ def standarize(

     scaler = scaler or StandardScaler()
     if isinstance(data, np.ndarray):
-        standarized_data = scaler.fit_transform(X=data)
-        return standarized_data, scaler
+        if transform:
+            standarized_data = scaler.transform(X=data)
+        else:
+            standarized_data = scaler.fit_transform(X=data)
     elif isinstance(data, pd.DataFrame):
-        standarized_data = scaler.fit_transform(X=data.values)
+        if transform:
+            standarized_data = scaler.transform(X=data.values)
+        else:
+            standarized_data = scaler.fit_transform(X=data.values)
         standarized_data = pd.DataFrame(standarized_data, columns=data.columns)
     elif isinstance(data, xr.Dataset):
-        standarized_data = scaler.fit_transform(X=data.to_array().values)
+        if transform:
+            standarized_data = scaler.transform(X=data.to_array().values)
+        else:
+            standarized_data = scaler.fit_transform(X=data.to_array().values)
         standarized_data = xr.Dataset(
             {
                 var_name: (tuple(data.coords), standarized_data[i_var])
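
The behavioural difference in one sketch: fit_transform re-learns mean/std on every call, while transform=True reuses the statistics learned earlier. The import path is assumed and the data is toy data:

    import numpy as np

    from bluemath_tk.core.operations import standarize  # assumed import path

    rng = np.random.default_rng(42)
    train = rng.normal(loc=5.0, scale=2.0, size=(1000, 3))
    new = rng.normal(loc=5.0, scale=2.0, size=(10, 3))

    # First call fits the scaler on the training data.
    train_std, scaler = standarize(data=train)

    # Later calls reuse the fitted statistics instead of re-fitting on `new`.
    new_std, _ = standarize(data=new, scaler=scaler, transform=True)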

bluemath_tk/core/plotting/base_plotting.py

Lines changed: 5 additions & 0 deletions
@@ -111,6 +111,11 @@ def get_subplots(self, **kwargs):
         fig, ax = plt.subplots(**kwargs)
         return fig, ax

+    def get_subplot(self, figsize, **kwargs):
+        fig = plt.figure(figsize=figsize)
+        ax = fig.add_subplot(**kwargs)
+        return fig, ax
+
     def plot_line(self, ax, **kwargs):
         c = kwargs.get("c", self.line_defaults.get("color"))
         kwargs.pop("c", None)
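
A plausible motivation for the new helper: fig.add_subplot accepts per-axes keywords such as projection directly, which plt.subplots only takes via subplot_kw. A standalone sketch of what get_subplot does, with self omitted and the cartopy projection chosen only for illustration:

    import cartopy.crs as ccrs
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(projection=ccrs.PlateCarree())  # kwargs pass straight through
    ax.coastlines()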

bluemath_tk/core/plotting/colors.py

Lines changed: 36 additions & 1 deletion
@@ -1,6 +1,9 @@
+import cmocean
 import matplotlib.colors as mcolors
-from matplotlib import cm
 import numpy as np
+from matplotlib import cm
+from matplotlib import pyplot as plt
+from matplotlib.colors import ListedColormap

 default_colors = [
     "#636EFA",
@@ -232,3 +235,35 @@ def GetFamsColors(num_fams):
     np_colors_rgb = colors_interp(num_fams)  # interpolate

     return np_colors_rgb
+
+
+def colormap_bathy(topat, topag):  # maximum topo, minimum bathy
+    """
+    Returns a custom colormap for bathymetry plots.
+    """
+
+    colors2 = "YlGnBu_r"
+    colors1 = cmocean.cm.turbid
+
+    bottom = plt.get_cmap(colors2, -topag * 100)
+    top = plt.get_cmap(colors1, topat * 100)
+
+    newcolors = np.vstack(
+        (
+            bottom(np.linspace(0, 0.8, -topag * 100)),
+            top(np.linspace(0.1, 1, topat * 100)),
+        )
+    )
+
+    return ListedColormap(newcolors)
+
+
+def colormap_spectra():
+    top = cm.get_cmap("RdBu", 128)
+    bottom = cm.get_cmap("rainbow", 128)
+    newcolors = np.vstack(
+        (top(np.linspace(0.5, 0.8, 50)), bottom(np.linspace(0.2, 1, 128)))
+    )
+    newcmp = ListedColormap(newcolors, name="newcmp")
+
+    return newcmp
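
A usage sketch for colormap_bathy; the import path is assumed and the elevation bounds are illustrative (topat metres of topography above the datum, topag metres of bathymetry below, so topag is negative):

    import matplotlib.pyplot as plt
    import numpy as np

    from bluemath_tk.core.plotting.colors import colormap_bathy  # assumed import path

    cmap = colormap_bathy(topat=10, topag=-30)  # +10 m topography, -30 m bathymetry

    # Toy elevation grid spanning the same range.
    elevation = np.linspace(-30, 10, 400).reshape(20, 20)
    plt.pcolormesh(elevation, cmap=cmap, vmin=-30, vmax=10)
    plt.colorbar(label="elevation [m]")
    plt.show()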

bluemath_tk/datamining/pca.py

Lines changed: 2 additions & 1 deletion
@@ -1,11 +1,11 @@
 from typing import List, Union

+import cartopy.crs as ccrs
 import numpy as np
 import xarray as xr
 from sklearn.decomposition import PCA as PCA_
 from sklearn.decomposition import IncrementalPCA as IncrementalPCA_
 from sklearn.preprocessing import StandardScaler
-import cartopy.crs as ccrs

 from ..core.decorators import validate_data_pca
 from ._base_datamining import BaseReduction
@@ -366,6 +366,7 @@ def _preprocess_data(
             standarized_stacked_data_matrix, scaler = self.standarize(
                 data=stacked_data_matrix,
                 scaler=self.scaler if not is_fit else None,
+                transform=not is_fit,
             )
         else:
             self.logger.warning("Data is not standarized")
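
This is the standarize bug named in the commit message: _preprocess_data passed the stored scaler back in, but fit_transform re-fitted it on the new data, so projections were no longer expressed in the training statistics. A minimal sklearn sketch of the difference, on toy data:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    fit_data = rng.normal(loc=0.0, scale=1.0, size=(500, 2))
    new_data = rng.normal(loc=3.0, scale=2.0, size=(500, 2))

    # Buggy path: fit_transform re-fits on new_data, so its shift disappears.
    refit = StandardScaler().fit(fit_data).fit_transform(new_data)

    # Fixed path (transform=not is_fit): training statistics are kept.
    kept = StandardScaler().fit(fit_data).transform(new_data)

    print(refit.mean(axis=0).round(2))  # ~[0. 0.] -- shift erased
    print(kept.mean(axis=0).round(2))   # ~[3. 3.] -- shift preserved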

bluemath_tk/deeplearning/generators/ncDataGenerator.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,6 @@
+import keras.utils
 import numpy as np
 import xarray as xr
-import keras.utils


 class DataGenerator(keras.utils.Sequence):
@@ -49,7 +49,6 @@ def counter_reset(self):
         self.counter = 0

     def __getitem__(self, idx):
-
         # prepare the resulting array
         inputs = np.zeros((self.batch_size, 64, 64, 1))
         outputs = np.zeros((self.batch_size, 64, 64, 1))
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor
+
+# Download training data from open datasets.
+training_data = datasets.FashionMNIST(
+    root="data",
+    train=True,
+    download=True,
+    transform=ToTensor(),
+)
+
+# Download test data from open datasets.
+test_data = datasets.FashionMNIST(
+    root="data",
+    train=False,
+    download=True,
+    transform=ToTensor(),
+)
+
+batch_size = 64
+
+# Create data loaders.
+train_dataloader = DataLoader(training_data, batch_size=batch_size)
+test_dataloader = DataLoader(test_data, batch_size=batch_size)
+
+for X, y in test_dataloader:
+    print(f"Shape of X [N, C, H, W]: {X.shape}")
+    print(f"Shape of y: {y.shape} {y.dtype}")
+    break
+
+device = (
+    torch.accelerator.current_accelerator().type
+    if torch.accelerator.is_available()
+    else "cpu"
+)
+print(f"Using {device} device")
+
+
+# Define model
+class NeuralNetwork(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28 * 28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10),
+        )
+
+    def forward(self, x):
+        x = self.flatten(x)
+        logits = self.linear_relu_stack(x)
+        return logits
+
+
+model = NeuralNetwork().to(device)
+print(model)
+
+loss_fn = nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+
+def train(dataloader, model, loss_fn, optimizer):
+    size = len(dataloader.dataset)
+    model.train()
+    for batch, (X, y) in enumerate(dataloader):
+        X, y = X.to(device), y.to(device)
+
+        # Compute prediction error
+        pred = model(X)
+        loss = loss_fn(pred, y)
+
+        # Backpropagation
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+        if batch % 100 == 0:
+            loss, current = loss.item(), (batch + 1) * len(X)
+            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
+
+
+def test(dataloader, model, loss_fn):
+    size = len(dataloader.dataset)
+    num_batches = len(dataloader)
+    model.eval()
+    test_loss, correct = 0, 0
+    with torch.no_grad():
+        for X, y in dataloader:
+            X, y = X.to(device), y.to(device)
+            pred = model(X)
+            test_loss += loss_fn(pred, y).item()
+            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
+    test_loss /= num_batches
+    correct /= size
+    print(
+        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
+    )
+
+
+epochs = 5
+for t in range(epochs):
+    print(f"Epoch {t + 1}\n-------------------------------")
+    train(train_dataloader, model, loss_fn, optimizer)
+    test(test_dataloader, model, loss_fn)
+print("Done!")
