diff --git a/AUTHORS.md b/AUTHORS.md
index ea00192d..6e935a78 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -1,3 +1,5 @@
 Daniel J. Hofmann https://github.com/daniel-j-h
 Bhargav Kowshik https://github.com/bkowshik
+
+Olivier Courtin https://github.com/ocourtin
diff --git a/README.md b/README.md
index b70ed2b8..b5fd0853 100644
--- a/README.md
+++ b/README.md
@@ -87,7 +87,7 @@ The following describes the installation from scratch.
 - Install native system dependencies required for Python 3 bindings
 
 ```bash
-apt-get install build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev
+apt-get install build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev libjpeg-turbo8-dev libwebp-dev
 ```
 
 - Use a virtualenv for installing this project locally
diff --git a/config/config.toml b/config/config.toml
new file mode 100644
index 00000000..8546b1ca
--- /dev/null
+++ b/config/config.toml
@@ -0,0 +1,55 @@
+# RoboSat Configuration
+# For syntax see: https://github.com/toml-lang/toml#table-of-contents
+
+[dataset]
+  # The slippy map dataset's base directory.
+  path = '/tmp/slippy-map-dir/'
+
+  # Dataset specific class weights, computed on the training data.
+  # Needed by the 'mIoU' and 'CrossEntropy' losses to deal with unbalanced classes.
+  # Note: use `./rs weights -h` to compute these for new datasets.
+  weights = [1.6248, 5.762827]
+
+
+[classes]
+  # Human readable titles for classes.
+  titles = ['background', 'parking']
+
+  # Color map for visualization and for representing classes in masks.
+  # Note: available colors are either CSS3 color names or #RRGGBB hexadecimal values.
+  colors = ['denim', 'orange']
+
+
+# Channel blocks indicate which dataset sub-directory and bands to take as input.
+# You can add several channel blocks to compose your input tensor; order is meaningful.
+[[channels]]
+sub = "images"
+bands = [1,2,3]
+
+
+# Model specific attributes.
+[model]
+
+  # Batch size for training.
+  batch_size = 2
+
+  # Image side size in pixels.
+  image_size = 512
+
+  # Total number of epochs to train for.
+  epochs = 10
+
+  # Learning rate for the optimizer.
+  lr = 0.0001
+
+  # Weight decay L2 penalty for the optimizer.
+  decay = 0.0001
+
+  # Loss function name (e.g. 'Lovasz', 'mIoU' or 'CrossEntropy').
+  loss = 'Lovasz'
+
+  # Data augmentation (flip or rotate) probability.
+  data_augmentation = 0.75
+
+  # Use ImageNet weights pretraining.
+  pretrained = true
diff --git a/config/dataset-parking.toml b/config/dataset-parking.toml
deleted file mode 100644
index d45d545a..00000000
--- a/config/dataset-parking.toml
+++ /dev/null
@@ -1,23 +0,0 @@
-# Configuration related to a specific dataset.
-# For syntax see: https://github.com/toml-lang/toml#table-of-contents
-
-
-# Dataset specific common attributes.
-[common]
-
-  # The slippy map dataset's base directory.
-  dataset = '/tmp/slippy-map-dir/'
-
-  # Human representation for classes.
-  classes = ['background', 'parking']
-
-  # Color map for visualization and representing classes in masks.
-  # Note: available colors can be found in `robosat/colors.py`
-  colors = ['denim', 'orange']
-
-
-# Dataset specific class weights computes on the training data.
-# Needed by 'mIoU' and 'CrossEntropy' losses to deal with unbalanced classes.
-# Note: use `./rs weights -h` to compute these for new datasets.
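The consolidated `config.toml` above replaces the per-dataset and per-model files deleted below. A minimal sketch of how a tool can derive its dimensions from it, assuming the `toml` package already listed in `deps/requirements.txt` and the example values above:

```python
# Hypothetical sketch: deriving model dimensions from the consolidated config.
import toml

config = toml.load("config/config.toml")

num_classes = len(config["classes"]["titles"])                    # 2: background, parking
num_channels = sum(len(c["bands"]) for c in config["channels"])   # 3 for a single RGB block

# Class weights must line up with the class titles.
assert len(config["dataset"]["weights"]) == num_classes
```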
-[weights] - values = [1.6248, 5.762827] diff --git a/config/model-unet.toml b/config/model-unet.toml deleted file mode 100644 index 6effd699..00000000 --- a/config/model-unet.toml +++ /dev/null @@ -1,34 +0,0 @@ -# Configuration related to a specific model. -# For syntax see: https://github.com/toml-lang/toml#table-of-contents - - -# Model specific common attributes. -[common] - - # Use CUDA for GPU acceleration. - cuda = true - - # Batch size for training. - batch_size = 2 - - # Image side size in pixels. - image_size = 512 - - # Directory where to save checkpoints to during training. - checkpoint = '/tmp/pth/' - - -# Model specific optimization parameters. -[opt] - - # Total number of epochs to train for. - epochs = 10 - - # Learning rate for the optimizer. - lr = 0.0001 - - # Weight decay l2 penalty for the optimizer - decay = 0.0001 - - # Loss function name (e.g 'Lovasz', 'mIoU' or 'CrossEntropy') - loss = 'Lovasz' diff --git a/deps/requirements-lock.txt b/deps/requirements-lock.txt index 208ffdc1..f0aa911a 100644 --- a/deps/requirements-lock.txt +++ b/deps/requirements-lock.txt @@ -28,7 +28,7 @@ pyparsing==2.2.2 pyproj==1.9.5.1 pytest==3.9.1 python-dateutil==2.7.3 -rasterio==1.0.8 +rasterio==1.0.10 requests==2.20.0 Rtree==0.8.3 scikit-learn==0.20.0 @@ -43,3 +43,4 @@ torchvision==0.2.1 tqdm==4.27.0 urllib3==1.24 Werkzeug==0.14.1 +webcolors==1.8.1 diff --git a/deps/requirements.txt b/deps/requirements.txt index 1c403232..19a5abe7 100644 --- a/deps/requirements.txt +++ b/deps/requirements.txt @@ -1,6 +1,6 @@ torchvision numpy -pillow +pillow-simd scipy opencv-contrib-python tqdm @@ -17,3 +17,4 @@ rtree pyproj toml pytest +webcolors diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index d824fd5f..01237605 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -6,7 +6,7 @@ FROM ubuntu:16.04 # See: https://github.com/skvark/opencv-python/issues/90 RUN apt-get update -qq && \ apt-get install -qq -y -o quiet=1 \ - python3 python3-dev python3-tk python3-pip build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev libsm6 + python3 python3-dev python3-tk python3-pip build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev libsm6 libwebp-dev libjpeg-turbo8-dev WORKDIR /app ADD . /app diff --git a/docker/Dockerfile.gpu b/docker/Dockerfile.gpu index 776ef8b3..02b3baa5 100644 --- a/docker/Dockerfile.gpu +++ b/docker/Dockerfile.gpu @@ -6,7 +6,7 @@ FROM nvidia/cuda:9.1-cudnn7-runtime-ubuntu16.04 # See: https://github.com/skvark/opencv-python/issues/90 RUN apt-get update -qq && \ apt-get install -qq -y -o quiet=1 \ - python3 python3-dev python3-tk python3-pip build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev libsm6 + python3 python3-dev python3-tk python3-pip build-essential libboost-python-dev libexpat1-dev zlib1g-dev libbz2-dev libspatialindex-dev libsm6 libwebp-dev libjpeg-turbo8-dev WORKDIR /app ADD . /app diff --git a/robosat/colors.py b/robosat/colors.py index 50a91dd0..3687e7b8 100644 --- a/robosat/colors.py +++ b/robosat/colors.py @@ -2,76 +2,30 @@ """ import colorsys - -from enum import Enum, unique - - -# Todo: user should be able to bring her own color palette. -# Functions need to account for that and not use one palette. - - -def _rgb(v): - r, g, b = v[1:3], v[3:5], v[5:7] - return int(r, 16), int(g, 16), int(b, 16) - - -@unique -class Mapbox(Enum): - """Mapbox-themed colors. 
-
-    See: https://www.mapbox.com/base/styling/color/
-    """
-
-    dark = _rgb("#404040")
-    gray = _rgb("#eeeeee")
-    light = _rgb("#f8f8f8")
-    white = _rgb("#ffffff")
-    cyan = _rgb("#3bb2d0")
-    blue = _rgb("#3887be")
-    bluedark = _rgb("#223b53")
-    denim = _rgb("#50667f")
-    navy = _rgb("#28353d")
-    navydark = _rgb("#222b30")
-    purple = _rgb("#8a8acb")
-    teal = _rgb("#41afa5")
-    green = _rgb("#56b881")
-    yellow = _rgb("#f1f075")
-    mustard = _rgb("#fbb03b")
-    orange = _rgb("#f9886c")
-    red = _rgb("#e55e5e")
-    pink = _rgb("#ed6498")
+import webcolors
+import numpy as np
 
 
 def make_palette(*colors):
-    """Builds a PIL-compatible color palette from color names.
+    """Builds a PIL-compatible color palette from CSS3 color names or #RRGGBB hex values.
 
     Args:
       colors: variable number of color names.
     """
 
-    rgbs = [Mapbox[color].value for color in colors]
-    flattened = sum(rgbs, ())
-    return list(flattened)
-
-
-def color_string_to_rgb(color):
-    """Convert color string to a list of RBG integers.
+    assert 0 < len(colors) <= 256
 
-    Args:
-      color: the string color value for example "250,0,0"
-
-    Returns:
-      color: as a list of RGB integers for example [250,0,0]
-    """
+    hexs = [webcolors.CSS3_NAMES_TO_HEX[color] if color[0] != "#" else color for color in colors]
+    rgbs = [(int(h[1:3], 16), int(h[3:5], 16), int(h[5:7], 16)) for h in hexs]
 
-    return [*map(int, color.split(","))]
+    return list(sum(rgbs, ()))
 
 
 def continuous_palette_for_color(color, bins=256):
     """Creates a continuous color palette based on a single color.
 
     Args:
-      color: the rgb color tuple to create a continuous palette for.
+      color: the CSS3 color name or its #RRGGBB hex value to create a continuous palette for.
       bins: the number of colors to create in the continuous palette.
 
     Returns:
@@ -81,15 +35,29 @@
     # A quick and dirty way to create a continuous color palette is to convert from the RGB color
     # space into the HSV color space and then only adapt the color's saturation (S component).
 
-    r, g, b = [v / 255 for v in Mapbox[color].value]
+    hex = webcolors.CSS3_NAMES_TO_HEX[color] if color[0] != "#" else color
+    r, g, b = [int(hex[i : i + 2], 16) / 255 for i in (1, 3, 5)]
     h, s, v = colorsys.rgb_to_hsv(r, g, b)
 
-    palette = []
+    assert 0 < bins <= 256
 
+    palette = []
     for i in range(bins):
-        ns = (1 / bins) * (i + 1)
-        palette.extend([int(v * 255) for v in colorsys.hsv_to_rgb(h, ns, v)])
-
-    assert len(palette) // 3 == bins
+        r, g, b = [int(c * 255) for c in colorsys.hsv_to_rgb(h, (1 / bins) * (i + 1), v)]
+        palette.extend((r, g, b))
 
     return palette
+
+
+def complementary_palette(palette):
+    """Creates a PIL complementary colors palette based on an initial PIL palette."""
+
+    comp_palette = []
+    colors = [palette[i : i + 3] for i in range(0, len(palette), 3)]
+
+    for color in colors:
+        r, g, b = color
+        h, s, v = colorsys.rgb_to_hsv(r, g, b)
+        comp_palette.extend(map(int, colorsys.hsv_to_rgb((h + 0.5) % 1, s, v)))
+
+    return comp_palette
diff --git a/robosat/datasets.py b/robosat/datasets.py
index 80659679..1c7d22a2 100644
--- a/robosat/datasets.py
+++ b/robosat/datasets.py
@@ -5,9 +5,13 @@
 See: http://pytorch.org/docs/0.3.1/data.html
 """
 
+import os
+import sys
 import torch
 from PIL import Image
 import torch.utils.data
+import cv2
+import numpy as np
 
 from robosat.tiles import tiles_from_slippy_map, buffer_tile_image
 
@@ -17,7 +21,7 @@ class SlippyMapTiles(torch.utils.data.Dataset):
     """Dataset for images stored in slippy map format."""
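A quick usage sketch for the reworked palette helpers in `robosat/colors.py` above; the mask content is made up, and a hex value is used since `make_palette` now accepts both `#RRGGBB` strings and CSS3 color names:

```python
# Hypothetical sketch: attaching a palette to a binary mask, as the tools do.
import numpy as np
from PIL import Image

from robosat.colors import make_palette, complementary_palette

palette = make_palette("#50667f", "orange")  # "#RRGGBB" hex or CSS3 name, background then foreground
mask = np.zeros((256, 256), dtype=np.uint8)  # fake mask: 0 = background, 1 = parking
mask[64:192, 64:192] = 1

out = Image.fromarray(mask, mode="P")
out.putpalette(palette)                      # or complementary_palette(palette) for inverted hues
out.save("/tmp/mask.png", optimize=True)
```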
""" - def __init__(self, root, transform=None): + def __init__(self, root, mode, transform=None): super().__init__() self.tiles = [] @@ -25,13 +29,27 @@ def __init__(self, root, transform=None): self.tiles = [(tile, path) for tile, path in tiles_from_slippy_map(root)] self.tiles.sort(key=lambda tile: tile[0]) + self.mode = mode def __len__(self): return len(self.tiles) def __getitem__(self, i): tile, path = self.tiles[i] - image = Image.open(path) + + if self.mode == "image": + image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) + + elif self.mode == "multibands": + image = cv2.imread(path, cv2.IMREAD_ANYCOLOR) + if len(image.shape) == 3 and image.shape[2] >= 3: + # FIXME Look twice to find an in-place way to perform a multiband BGR2RGB + g = image[:, :, 0] + image[:, :, 0] = image[:, :, 2] + image[:, :, 2] = g + + elif self.mode == "mask": + image = np.array(Image.open(path).convert("P")) if self.transform is not None: image = self.transform(image) @@ -40,42 +58,49 @@ def __getitem__(self, i): # Multiple Slippy Map directories. -# Think: one with images, one with masks, one with rasterized traces. class SlippyMapTilesConcatenation(torch.utils.data.Dataset): """Dataset to concate multiple input images stored in slippy map format. """ - def __init__(self, inputs, target, joint_transform=None): + def __init__(self, path, channels, target, joint_transform=None): super().__init__() - # No transformations in the `SlippyMapTiles` instead joint transformations in getitem - self.joint_transform = joint_transform + assert len(channels), "Channels configuration empty" + self.channels = channels + self.inputs = dict() + + for channel in channels: + for band in channel["bands"]: + self.inputs[channel["sub"]] = SlippyMapTiles(os.path.join(path, channel["sub"]), mode="multibands") - self.inputs = [SlippyMapTiles(inp) for inp in inputs] - self.target = SlippyMapTiles(target) + self.target = SlippyMapTiles(target, mode="mask") - assert len(set([len(dataset) for dataset in self.inputs])) == 1, "same number of tiles in all images" - assert len(self.target) == len(self.inputs[0]), "same number of tiles in images and label" + # No transformations in the `SlippyMapTiles` instead joint transformations in getitem + self.joint_transform = joint_transform def __len__(self): return len(self.target) def __getitem__(self, i): - # at this point all transformations are applied and we expect to work with raw tensors - inputs = [dataset[i] for dataset in self.inputs] - images = [image for image, _ in inputs] - tiles = [tile for _, tile in inputs] + mask, tile = self.target[i] - mask, mask_tile = self.target[i] + for channel in self.channels: + try: + data, band_tile = self.inputs[channel["sub"]][i] + assert band_tile == tile - assert len(set(tiles)) == 1, "all images are for the same tile" - assert tiles[0] == mask_tile, "image tile is the same as label tile" + for band in channel["bands"]: + data_band = data[:, :, int(band) - 1] if len(data.shape) == 3 else data_band + data_band = data_band.reshape(mask.shape[0], mask.shape[1], 1) + tensor = np.concatenate((tensor, data_band), axis=2) if "tensor" in locals() else data_band + except: + sys.exit("Unable to concatenate input Tensor") if self.joint_transform is not None: - images, mask = self.joint_transform(images, mask) + tensor, mask = self.joint_transform(tensor, mask) - return torch.cat(images, dim=0), mask, tiles + return tensor, mask, tile # Todo: once we have the SlippyMapDataset this dataset should wrap @@ -113,7 +138,7 @@ def __len__(self): def 
         tile, path = self.tiles[i]
-        image = buffer_tile_image(tile, self.tiles, overlap=self.overlap, tile_size=self.size)
+        image = np.array(buffer_tile_image(tile, self.tiles, overlap=self.overlap, tile_size=self.size))
 
         if self.transform is not None:
             image = self.transform(image)
diff --git a/robosat/log.py b/robosat/log.py
index 2e7f9ff1..7616ed08 100644
--- a/robosat/log.py
+++ b/robosat/log.py
@@ -11,17 +11,27 @@ class Log:
     """
 
     def __init__(self, path, out=sys.stdout):
+
+        self.fp = None
         self.out = out
-        self.fp = open(path, "a")
-        assert self.fp, "Unable to open log file"
+        try:
+            if path:
+                if not os.path.isdir(os.path.dirname(path)):
+                    os.makedirs(os.path.dirname(path), exist_ok=True)
+                self.fp = open(path, mode="a")
+        except:
+            sys.exit("Unable to write to log directory")
 
     """Log a new message to the opened log file, and optionnaly on stdout or stderr too"""
 
     def log(self, msg):
-        assert self.fp, "Unable to write in log file"
-        self.fp.write(msg + os.linesep)
-        self.fp.flush()
-
-        if self.out:
-            print(msg, file=self.out)
+        try:
+            if self.fp:
+                self.fp.write(msg + os.linesep)
+                self.fp.flush()
+
+            if self.out:
+                print(msg, file=self.out)
+        except:
+            sys.exit("Unable to write to log file")
diff --git a/robosat/metrics.py b/robosat/metrics.py
index 5d125e2a..333de523 100644
--- a/robosat/metrics.py
+++ b/robosat/metrics.py
@@ -24,16 +24,19 @@ def __init__(self, labels):
         self.fp = 0
         self.tp = 0
 
-    def add(self, actual, predicted):
+    def add(self, label, predicted, is_prob=True):
         """Adds an observation to the tracker.
 
         Args:
-          actual: the ground truth labels.
-          predicted: the predicted labels.
+          label: the ground truth labels.
+          predicted: the predicted probabilities or mask.
+          is_prob: whether `predicted` holds probabilities or an already-computed mask.
         """
 
-        masks = torch.argmax(predicted, 0)
-        confusion = masks.view(-1).float() / actual.view(-1).float()
+        if is_prob:
+            predicted = torch.argmax(predicted, 0)
+
+        confusion = predicted.view(-1).float() / label.view(-1).float()
 
         self.tn += torch.sum(torch.isnan(confusion)).item()
         self.fn += torch.sum(confusion == float("inf")).item()
@@ -46,7 +49,13 @@ def get_miou(self):
         Returns:
           The mean Intersection over Union score for all observations seen so far.
         """
-        return np.nanmean([self.tn / (self.tn + self.fn + self.fp), self.tp / (self.tp + self.fn + self.fp)])
+
+        try:
+            miou = np.nanmean([self.tn / (self.tn + self.fn + self.fp), self.tp / (self.tp + self.fn + self.fp)])
+        except ZeroDivisionError:
+            miou = float("NaN")
+
+        return miou
 
     def get_fg_iou(self):
         """Retrieves the foreground Intersection over Union score.
@@ -58,7 +67,7 @@ def get_fg_iou(self):
         try:
             iou = self.tp / (self.tp + self.fn + self.fp)
         except ZeroDivisionError:
-            iou = float("Inf")
+            iou = float("NaN")
 
         return iou
 
@@ -74,7 +83,7 @@ def get_mcc(self):
                 (self.tp + self.fp) * (self.tp + self.fn) * (self.tn + self.fp) * (self.tn + self.fn)
             )
         except ZeroDivisionError:
-            mcc = float("Inf")
+            mcc = float("NaN")
 
         return mcc
diff --git a/robosat/tiles.py b/robosat/tiles.py
index 24a09a1e..a1895b88 100644
--- a/robosat/tiles.py
+++ b/robosat/tiles.py
@@ -11,8 +11,14 @@
 import csv
 import io
 import os
+from glob import glob
 
+import cv2
 from PIL import Image
+import numpy as np
+
+from rasterio.warp import transform
+from rasterio.crs import CRS
 
 import mercantile
 
@@ -81,6 +87,7 @@ def isdigit(v):
         except ValueError:
             return False
 
+    root = os.path.expanduser(root)
     for z in os.listdir(root):
         if not isdigit(z):
             continue
@@ -110,6 +117,7 @@ def tiles_from_csv(path):
         The mercantile tiles from the csv file.
     """
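The division trick in `Metrics.add` above is compact but non-obvious: dividing the predicted mask by the label maps each pixel to `nan` (0/0), `inf` (1/0), `0` (0/1) or `1` (1/1). A standalone sketch of the same idea; note the tracker counts `inf` under `fn` and `0` under `fp`, a naming swap relative to this sketch that cancels out since the IoU and MCC formulas are symmetric in the pair:

```python
# Standalone sketch of the confusion-by-division trick used in Metrics.add.
import torch

label     = torch.tensor([0, 0, 1, 1]).float()
predicted = torch.tensor([0, 1, 0, 1]).float()

confusion = predicted / label

tn = torch.sum(torch.isnan(confusion)).item()     # 0/0 -> nan: true negative
fp = torch.sum(confusion == float("inf")).item()  # 1/0 -> inf: false positive
fn = torch.sum(confusion == 0).item()             # 0/1 -> 0:   false negative
tp = torch.sum(confusion == 1).item()             # 1/1 -> 1:   true positive

assert (tn, fp, fn, tp) == (1, 1, 1, 1)
```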
""" + path = os.path.expanduser(path) with open(path) as fp: reader = csv.reader(fp) @@ -120,27 +128,25 @@ def tiles_from_csv(path): yield mercantile.Tile(*map(int, row)) -def stitch_image(into, into_box, image, image_box): - """Stitches two images together in-place. - - Args: - into: the image to stitch into and modify in-place. - into_box: left, upper, right, lower image coordinates for where to place `image` in `into`. - image: the image to stitch into `into`. - image_box: left, upper, right, lower image coordinates for where to extract the sub-image from `image`. +def tile_image(root, x, y, z): + """Retrieves H,W,C numpy array, from a tile store and X,Y,Z coordinates, or `None`""" - Note: - Both boxes must be of same size. - """ + try: + root = os.path.expanduser(root) + path = glob(os.path.join(root, z, x, y) + "*") + assert len(path) == 1 + img = np.array(Image.open(path[0]).convert("RGB")) + except: + return None - into.paste(image.crop(box=image_box), box=into_box) + return img -def adjacent_tile(tile, dx, dy, tiles): - """Retrieves an adjacent tile from a tile store. +def adjacent_tile_image(tile, dx, dy, tiles): + """Retrieves an adjacent tile image from a tile store. Args: - tile: the original tile to get an adjacent tile for. + tile: the original tile to get an adjacent tile image for. dx: the offset in tile x direction. dy: the offset in tile y direction. tiles: the tile store to get tiles from; must support `__getitem__` with tiles. @@ -150,16 +156,17 @@ def adjacent_tile(tile, dx, dy, tiles): """ x, y, z = map(int, [tile.x, tile.y, tile.z]) - other = mercantile.Tile(x=x + dx, y=y + dy, z=z) + adjacent = mercantile.Tile(x=x + dx, y=y + dy, z=z) try: - path = tiles[other] - return Image.open(path).convert("RGB") + path = tiles[adjacent] except KeyError: return None + return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) + -def buffer_tile_image(tile, tiles, overlap, tile_size, nodata=0): +def buffer_tile_image(tile, tiles, overlap, tile_size): """Buffers a tile image adding borders on all sides based on adjacent tiles. Args: @@ -167,61 +174,45 @@ def buffer_tile_image(tile, tiles, overlap, tile_size, nodata=0): tiles: available tiles; must be a mapping of tiles to their filesystem paths. overlap: the tile border to add on every side; in pixel. tile_size: the tile size. - nodata: the color value to use when no adjacent tile is available. Returns: - The composite image containing the original tile plus tile overlap on all sides. + The H,W,C numpy composite image containing the original tile plus tile overlap on all sides. It's size is `tile_size` + 2 * `overlap` pixel for each side. 
""" + assert 0 <= overlap <= tile_size, "Overlap value can't be either negative or bigger than tile_size" + tiles = dict(tiles) x, y, z = map(int, [tile.x, tile.y, tile.z]) + # 3x3 matrix (upper, center, bottom) x (left, center, right) + ul = adjacent_tile_image(tile, -1, -1, tiles) + uc = adjacent_tile_image(tile, +0, -1, tiles) + ur = adjacent_tile_image(tile, +1, -1, tiles) + cl = adjacent_tile_image(tile, -1, +0, tiles) + cc = adjacent_tile_image(tile, +0, +0, tiles) + cr = adjacent_tile_image(tile, +1, +0, tiles) + bl = adjacent_tile_image(tile, -1, +1, tiles) + bc = adjacent_tile_image(tile, +0, +1, tiles) + br = adjacent_tile_image(tile, +1, +1, tiles) + + ts = tile_size + o = overlap + oo = overlap * 2 + # Todo: instead of nodata we should probably mirror the center image - composite_size = tile_size + 2 * overlap - composite = Image.new(mode="RGB", size=(composite_size, composite_size), color=nodata) - - path = tiles[tile] - center = Image.open(path).convert("RGB") - composite.paste(center, box=(overlap, overlap)) - - top_left = adjacent_tile(tile, -1, -1, tiles) - top_right = adjacent_tile(tile, +1, -1, tiles) - bottom_left = adjacent_tile(tile, -1, +1, tiles) - bottom_right = adjacent_tile(tile, +1, +1, tiles) - - top = adjacent_tile(tile, 0, -1, tiles) - left = adjacent_tile(tile, -1, 0, tiles) - bottom = adjacent_tile(tile, 0, +1, tiles) - right = adjacent_tile(tile, +1, 0, tiles) - - def maybe_stitch(maybe_tile, composite_box, tile_box): - if maybe_tile: - stitch_image(composite, composite_box, maybe_tile, tile_box) - - maybe_stitch(top_left, (0, 0, overlap, overlap), (tile_size - overlap, tile_size - overlap, tile_size, tile_size)) - maybe_stitch( - top_right, (tile_size + overlap, 0, composite_size, overlap), (0, tile_size - overlap, overlap, tile_size) - ) - maybe_stitch( - bottom_left, - (0, composite_size - overlap, overlap, composite_size), - (tile_size - overlap, 0, tile_size, overlap), - ) - maybe_stitch( - bottom_right, - (composite_size - overlap, composite_size - overlap, composite_size, composite_size), - (0, 0, overlap, overlap), - ) - maybe_stitch(top, (overlap, 0, composite_size - overlap, overlap), (0, tile_size - overlap, tile_size, tile_size)) - maybe_stitch(left, (0, overlap, overlap, composite_size - overlap), (tile_size - overlap, 0, tile_size, tile_size)) - maybe_stitch( - bottom, - (overlap, composite_size - overlap, composite_size - overlap, composite_size), - (0, 0, tile_size, overlap), - ) - maybe_stitch( - right, (composite_size - overlap, overlap, composite_size, composite_size - overlap), (0, 0, overlap, tile_size) - ) - - return composite + img = np.zeros((ts + oo, ts + oo, 3)).astype(np.uint8) + + # fmt:off + img[0:o, 0:o, :] = ul[-o:ts, -o:ts, :] if ul is not None else np.zeros((o, o, 3)).astype(np.uint8) + img[0:o, o:ts+o, :] = uc[-o:ts, 0:ts, :] if uc is not None else np.zeros((o, ts, 3)).astype(np.uint8) + img[0:o, ts+o:ts+oo, :] = ur[-o:ts, 0:o, :] if ur is not None else np.zeros((o, o, 3)).astype(np.uint8) + img[o:ts+o, 0:o, :] = cl[0:ts, -o:ts, :] if cl is not None else np.zeros((ts, o, 3)).astype(np.uint8) + img[o:ts+o, o:ts+o, :] = cc if cc is not None else np.zeros((ts, ts, 3)).astype(np.uint8) + img[o:ts+o, ts+o:ts+oo, :] = cr[0:ts, 0:o, :] if cr is not None else np.zeros((ts, o, 3)).astype(np.uint8) + img[ts+o:ts+oo, 0:o, :] = bl[0:o, -o:ts, :] if bl is not None else np.zeros((o, o, 3)).astype(np.uint8) + img[ts+o:ts+oo, o:ts+o, :] = bc[0:o, 0:ts, :] if bc is not None else np.zeros((o, ts, 3)).astype(np.uint8) + img[ts+o:ts+oo, 
ts+o:ts+oo, :] = br[0:o, 0:o, :] if br is not None else np.zeros((o, o, 3)).astype(np.uint8) + # fmt:on + + return img diff --git a/robosat/tools/__main__.py b/robosat/tools/__main__.py index a4cf2f0b..ac63459a 100644 --- a/robosat/tools/__main__.py +++ b/robosat/tools/__main__.py @@ -16,6 +16,7 @@ rasterize, serve, subset, + tile, train, weights, ) @@ -27,6 +28,7 @@ def add_parsers(): # Add your tool's entry point below. + tile.add_parser(subparser) extract.add_parser(subparser) cover.add_parser(subparser) download.add_parser(subparser) diff --git a/robosat/tools/compare.py b/robosat/tools/compare.py index 956fc969..4b2f3c05 100644 --- a/robosat/tools/compare.py +++ b/robosat/tools/compare.py @@ -1,67 +1,166 @@ import os +import sys +import math +import json +import torch import argparse from PIL import Image from tqdm import tqdm import numpy as np -from robosat.tiles import tiles_from_slippy_map +from mercantile import feature + +from robosat.colors import make_palette, complementary_palette +from robosat.tiles import tiles_from_slippy_map, tile_image +from robosat.config import load_config +from robosat.metrics import Metrics +from robosat.utils import web_ui +from robosat.log import Log def add_parser(subparser): parser = subparser.add_parser( - "compare", - help="compare images, labels and masks side by side", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, + "compare", help="compare images and/or labels and masks", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("out", type=str, help="directory to save visualizations to") - parser.add_argument("images", type=str, help="directory to read slippy map images from") - parser.add_argument("labels", type=str, help="directory to read slippy map labels from") - parser.add_argument("masks", type=str, nargs="+", help="slippy map directories to read masks from") - parser.add_argument("--minimum", type=float, default=0.0, help="minimum percentage of mask not background") - parser.add_argument("--maximum", type=float, default=1.0, help="maximum percentage of mask not background") + parser.add_argument("--mode", type=str, default="side", choices=["side", "stack", "list"], help="compare mode") + parser.add_argument("--images", type=str, nargs="+", help="slippy map images dirs to render (stack or side mode)") + parser.add_argument("--ext", type=str, default="webp", help="file format to save images in (stack or side mode)") + parser.add_argument("--labels", type=str, help="directory to read slippy map labels from (needed for QoD metric)") + parser.add_argument("--masks", type=str, help="directory to read slippy map masks from (needed for QoD metric)") + parser.add_argument("--config", type=str, help="path to configuration file (needed for QoD metric)") + parser.add_argument("--minimum_fg", type=float, default=0.0, help="skip tile if label foreground below, [0-100]") + parser.add_argument("--maximum_fg", type=float, default=100.0, help="skip tile if label foreground above, [0-100]") + parser.add_argument("--minimum_qod", type=float, default=0.0, help="skip tile if QoD metric below, [0-100]") + parser.add_argument("--maximum_qod", type=float, default=100.0, help="skip tile if QoD metric above, [0-100]") + parser.add_argument("--vertical", action="store_true", help="render vertical image aggregate, for side mode") + parser.add_argument("--geojson", action="store_true", help="output geojson based, for list mode") + parser.add_argument("--web_ui", type=str, help="web ui base url") + 
parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") + parser.add_argument("out", type=str, help="directory or path (upon mode) to save output to") parser.set_defaults(func=main) -def main(args): - images = tiles_from_slippy_map(args.images) +def compare(masks, labels, tile, classes): - for tile, path in tqdm(list(images), desc="Compare", unit="image", ascii=True): - x, y, z = list(map(str, tile)) + x, y, z = list(map(str, tile)) + label = np.array(Image.open(os.path.join(labels, z, x, "{}.png".format(y)))) + mask = np.array(Image.open(os.path.join(masks, z, x, "{}.png".format(y)))) - image = Image.open(path).convert("RGB") - label = Image.open(os.path.join(args.labels, z, x, "{}.png".format(y))).convert("P") - assert image.size == label.size + assert label.shape == mask.shape + assert len(label.shape) == 2 and len(classes) == 2 # Still binary centric - keep = False - masks = [] - for path in args.masks: - mask = Image.open(os.path.join(path, z, x, "{}.png".format(y))).convert("P") - assert image.size == mask.size - masks.append(mask) + metrics = Metrics(classes) + metrics.add(torch.from_numpy(label), torch.from_numpy(mask), is_prob=False) + fg_iou = metrics.get_fg_iou() - # TODO: The calculation below does not work for multi-class. - percentage = np.sum(np.array(mask) != 0) / np.prod(image.size) + fg_ratio = 100 * max(np.sum(mask != 0), np.sum(label != 0)) / mask.size + dist = 0.0 if math.isnan(fg_iou) else 1.0 - fg_iou - # Keep this image when percentage is within required threshold. - if percentage >= args.minimum and percentage <= args.maximum: - keep = True + qod = 100 - (dist * (math.log(fg_ratio + 1.0) + np.finfo(float).eps) * (100 / math.log(100))) + qod = 0.0 if qod < 0.0 else qod # Corner case prophilaxy - if not keep: - continue + return dist, fg_ratio, qod - width, height = image.size - # Columns for image, label and all the masks. 
-        columns = 2 + len(masks)
-        combined = Image.new(mode="RGB", size=(columns * width, height))
-
-        combined.paste(image, box=(0 * width, 0))
-        combined.paste(label, box=(1 * width, 0))
-        for i, mask in enumerate(masks):
-            combined.paste(mask, box=((2 + i) * width, 0))
-
-        os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
-        path = os.path.join(args.out, z, x, "{}.png".format(y))
-        combined.save(path, optimize=True)
+def main(args):
+
+    if not args.masks or not args.labels or not args.config:
+        if args.mode == "list":
+            sys.exit("Parameters masks, labels and config are all mandatory in list mode.")
+        if args.minimum_fg > 0 or args.maximum_fg < 100 or args.minimum_qod > 0 or args.maximum_qod < 100:
+            sys.exit("Parameters masks, labels and config are all mandatory for QoD filtering.")
+
+    if args.images:
+        tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])]
+        for image in args.images[1:]:
+            assert sorted(tiles) == sorted([tile for tile, _ in tiles_from_slippy_map(image)]), "inconsistent coverages"
+
+    if args.labels and args.masks:
+        tiles_masks = [tile for tile, _ in tiles_from_slippy_map(args.masks)]
+        tiles_labels = [tile for tile, _ in tiles_from_slippy_map(args.labels)]
+        if args.images:
+            assert sorted(tiles) == sorted(tiles_masks) == sorted(tiles_labels), "inconsistent coverages"
+        else:
+            assert sorted(tiles_masks) == sorted(tiles_labels), "inconsistent coverages"
+            tiles = tiles_masks
+
+    if args.mode == "list":
+        out = open(args.out, mode="w")
+        if args.geojson:
+            out.write('{"type":"FeatureCollection","features":[')
+            first = True
+
+    tiles_compare = []
+    for tile in tqdm(list(tiles), desc="Compare", unit="tile", ascii=True):
+
+        x, y, z = list(map(str, tile))
+
+        if args.masks and args.labels and args.config:
+            classes = load_config(args.config)["classes"]["titles"]
+            dist, fg_ratio, qod = compare(args.masks, args.labels, tile, classes)
+            if not args.minimum_fg <= fg_ratio <= args.maximum_fg or not args.minimum_qod <= qod <= args.maximum_qod:
+                continue
+
+        tiles_compare.append(tile)
+
+        if args.mode == "side":
+
+            for i, image in enumerate(args.images):
+                img = tile_image(image, x, y, z)
+
+                if i == 0:
+                    side = np.zeros((img.shape[0], img.shape[1] * len(args.images), 3))
+                    side = np.swapaxes(side, 0, 1) if args.vertical else side
+                    image_shape = img.shape
+                else:
+                    assert image_shape == img.shape, "Inconsistent image sizes to compare"
+
+                if args.vertical:
+                    side[i * image_shape[0] : (i + 1) * image_shape[0], :, :] = img
+                else:
+                    side[:, i * image_shape[0] : (i + 1) * image_shape[0], :] = img
+
+            os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
+            side = Image.fromarray(np.uint8(side))
+            side.save(os.path.join(args.out, z, x, "{}.{}".format(y, args.ext)), optimize=True)
+
+        elif args.mode == "stack":
+
+            for i, image in enumerate(args.images):
+                img = tile_image(image, x, y, z)
+
+                if i == 0:
+                    image_shape = img.shape[0:2]
+                    stack = img / len(args.images)
+                else:
+                    assert image_shape == img.shape[0:2], "Inconsistent image sizes to compare"
+                    stack = stack + (img / len(args.images))
+
+            os.makedirs(os.path.join(args.out, str(z), str(x)), exist_ok=True)
+            stack = Image.fromarray(np.uint8(stack))
+            stack.save(os.path.join(args.out, str(z), str(x), "{}.{}".format(y, args.ext)), optimize=True)
+
+        elif args.mode == "list":
+            if args.geojson:
+                prop = '"properties":{{"x":{},"y":{},"z":{},"fg":{:.1f},"qod":{:.1f}}}'.format(x, y, z, fg_ratio, qod)
+                geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"]))
out.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop)) + first = False + else: + out.write("{},{},{}\t\t{:.1f}\t\t{:.1f}{}".format(x, y, z, fg_ratio, qod, os.linesep)) + + if args.mode == "list": + if args.geojson: + out.write("]}") + out.close() + + elif args.mode == "side" and args.web_ui: + template = "compare.html" if not args.web_ui_template else args.web_ui_template + web_ui(args.out, args.web_ui, None, tiles_compare, args.ext, template) + + elif args.mode == "stack" and args.web_ui: + template = "leaflet.html" if not args.web_ui_template else args.web_ui_template + tiles = [tile for tile, _ in tiles_from_slippy_map(args.images[0])] + web_ui(args.out, args.web_ui, tiles, tiles_compare, args.ext, template) diff --git a/robosat/tools/cover.py b/robosat/tools/cover.py index ee5a2931..0c58ea42 100644 --- a/robosat/tools/cover.py +++ b/robosat/tools/cover.py @@ -1,37 +1,57 @@ +import os +import sys import argparse import csv import json -from supermercado import burntiles from tqdm import tqdm +from mercantile import tiles +from supermercado import burntiles + +from robosat.datasets import tiles_from_slippy_map def add_parser(subparser): parser = subparser.add_parser( "cover", - help="generates tiles covering GeoJSON features", + help="generates tiles covering, in csv format: X,Y,Z", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--zoom", type=int, required=True, help="zoom level of tiles") - parser.add_argument("features", type=str, help="path to GeoJSON features") - parser.add_argument("out", type=str, help="path to csv file to store tiles in") + parser.add_argument("--zoom", type=int, help="zoom level of tiles") + parser.add_argument("--type", type=str, default="geojson", choices=["geojson", "bbox", "dir"], help="input type") + help = "input value, upon type either: a geojson file path, a lat/lon bbox in ESPG:4326, or a slippymap dir path" + parser.add_argument("input", type=str, help=help) + parser.add_argument("out", type=str, help="path to csv file to generate") parser.set_defaults(func=main) def main(args): - with open(args.features) as f: - features = json.load(f) - tiles = [] + if not args.zoom and args.type in ["geojson", "bbox"]: + sys.exit("Zoom parameter is mandatory") + + cover = [] + + if args.type == "geojson": + with open(args.input) as f: + features = json.load(f) + + for feature in tqdm(features["features"], ascii=True, unit="feature"): + cover.extend(map(tuple, burntiles.burn([feature], args.zoom).tolist())) + + cover = list(set(cover)) # tiles can overlap for multiple features; unique tile ids + + elif args.type == "bbox": + west, south, east, north = map(float, args.input.split(",")) + cover = tiles(west, south, east, north, args.zoom) - for feature in tqdm(features["features"], ascii=True, unit="feature"): - tiles.extend(map(tuple, burntiles.burn([feature], args.zoom).tolist())) + elif args.type == "dir": + cover = [tile for tile, _ in tiles_from_slippy_map(args.input)] - # tiles can overlap for multiple features; unique tile ids - tiles = list(set(tiles)) + if not os.path.isdir(os.path.dirname(args.out)): + os.makedirs(os.path.dirname(args.out), exist_ok=True) with open(args.out, "w") as fp: - writer = csv.writer(fp) - writer.writerows(tiles) + csv.writer(fp).writerows(cover) diff --git a/robosat/tools/dedupe.py b/robosat/tools/dedupe.py index ab5d2445..fbd2c942 100644 --- a/robosat/tools/dedupe.py +++ b/robosat/tools/dedupe.py @@ -17,11 +17,11 @@ def add_parser(subparser): 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("osm", type=str, help="ground truth GeoJSON feature collection from OpenStreetMap") - parser.add_argument("predicted", type=str, help="predicted GeoJSON feature collection to deduplicate") parser.add_argument( "--threshold", type=float, required=True, help="maximum allowed IoU to keep predictions, between 0.0 and 1.0" ) + parser.add_argument("osm", type=str, help="ground truth GeoJSON feature collection from OpenStreetMap") + parser.add_argument("predicted", type=str, help="predicted GeoJSON feature collection to deduplicate") parser.add_argument("out", type=str, help="path to GeoJSON to save deduplicated features to") parser.set_defaults(func=main) diff --git a/robosat/tools/download.py b/robosat/tools/download.py index b6ddf5a6..3d146dfe 100644 --- a/robosat/tools/download.py +++ b/robosat/tools/download.py @@ -7,18 +7,27 @@ import requests from PIL import Image from tqdm import tqdm +from mercantile import xy_bounds from robosat.tiles import tiles_from_csv, fetch_image +from robosat.utils import web_ui +from robosat.log import Log def add_parser(subparser): parser = subparser.add_parser( - "download", help="downloads images from Mapbox Maps API", formatter_class=argparse.ArgumentDefaultsHelpFormatter + "download", help="downloads images from a remote server", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("url", type=str, help="endpoint with {z}/{x}/{y} variables to fetch image tiles from") + parser.add_argument( + "url", type=str, help="endpoint with {z}/{x}/{y} or {xmin},{ymin},{xmax},{ymax} variables to fetch image tiles" + ) parser.add_argument("--ext", type=str, default="webp", help="file format to save images in") parser.add_argument("--rate", type=int, default=10, help="rate limit in max. 
requests per second") + parser.add_argument("--type", type=str, default="XYZ", choices=["XYZ", "WMS", "TMS"], help="service type to use") + parser.add_argument("--timeout", type=int, default=10, help="server request timeout (in seconds)") + parser.add_argument("--web_ui", type=str, help="web ui client base url") + parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") parser.add_argument("tiles", type=str, help="path to .csv tiles file") parser.add_argument("out", type=str, help="path to slippy map directory for storing tiles") @@ -27,10 +36,16 @@ def add_parser(subparser): def main(args): tiles = list(tiles_from_csv(args.tiles)) + already_dl = 0 + dl = 0 with requests.Session() as session: num_workers = args.rate + os.makedirs(os.path.join(args.out), exist_ok=True) + log = Log(os.path.join(args.out, "log"), out=sys.stderr) + log.log("Begin download from {}".format(args.url)) + # tqdm has problems with concurrent.futures.ThreadPoolExecutor; explicitly call `.update` # https://github.com/tqdm/tqdm/issues/97 progress = tqdm(total=len(tiles), ascii=True, unit="image") @@ -46,20 +61,26 @@ def worker(tile): path = os.path.join(args.out, z, x, "{}.{}".format(y, args.ext)) if os.path.isfile(path): - return tile, True - - url = args.url.format(x=tile.x, y=tile.y, z=tile.z) - - res = fetch_image(session, url) - + return tile, None, True + + if args.type == "XYZ": + url = args.url.format(x=tile.x, y=tile.y, z=tile.z) + elif args.type == "TMS": + tile.y = (2 ** tile.z) - tile.y - 1 + url = args.url.format(x=tile.x, y=tile.y, z=tile.z) + elif args.type == "WMS": + xmin, ymin, xmax, ymax = xy_bounds(tile) + url = args.url.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + + res = fetch_image(session, url, args.timeout) if not res: - return tile, False + return tile, url, False try: image = Image.open(res) image.save(path, optimize=True) except OSError: - return tile, False + return tile, url, False tock = time.monotonic() @@ -71,8 +92,21 @@ def worker(tile): progress.update() - return tile, True + return tile, url, True + + for tile, url, ok in executor.map(worker, tiles): + if url and ok: + dl += 1 + elif not url and ok: + already_dl += 1 + else: + log.log("Warning:\n {} failed, skipping.\n {}\n".format(tile, url)) + + if already_dl: + log.log("Notice:\n {} tiles were already downloaded previously, and so skipped now.".format(already_dl)) + if already_dl + dl == len(tiles): + log.log(" Coverage is fully downloaded.") - for tile, ok in executor.map(worker, tiles): - if not ok: - print("Warning: {} failed, skipping".format(tile), file=sys.stderr) + if args.web_ui: + template = "leaflet.html" if not args.web_ui_template else args.web_ui_template + web_ui(args.out, args.web_ui, tiles, tiles, args.ext, template) diff --git a/robosat/tools/export.py b/robosat/tools/export.py index 5a7ee5f3..af85d21e 100644 --- a/robosat/tools/export.py +++ b/robosat/tools/export.py @@ -1,8 +1,10 @@ import argparse +import os import torch import torch.onnx import torch.autograd +import torch.nn as nn from robosat.config import load_config from robosat.unet import UNet @@ -10,31 +12,52 @@ def add_parser(subparser): parser = subparser.add_parser( - "export", help="exports model in ONNX format", formatter_class=argparse.ArgumentDefaultsHelpFormatter + "export", help="exports or prunes model", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file") + parser.add_argument("--config", 
type=str, required=True, help="path to configuration file") + parser.add_argument("--export_channels", type=int, help="export channels to use (keep the first ones)") + parser.add_argument("--type", type=str, choices=["onnx", "pth"], default="onnx", help="output type") parser.add_argument("--image_size", type=int, default=512, help="image size to use for model") parser.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load") - parser.add_argument("model", type=str, help="path to save ONNX GraphProto .pb model to") + parser.add_argument("out", type=str, help="path to save export model to") parser.set_defaults(func=main) def main(args): - dataset = load_config(args.dataset) + config = load_config(args.config) - num_classes = len(dataset["common"]["classes"]) - net = UNet(num_classes) + if args.type == "onnx": + os.environ["CUDA_VISIBLE_DEVICES"] = "" + # Workaround: PyTorch ONNX, DataParallel with GPU issue, cf https://github.com/pytorch/pytorch/issues/5315 + + num_classes = len(config["classes"]["titles"]) + num_channels = 0 + for channel in config["channels"]: + num_channels += len(channel["bands"]) + + export_channels = num_channels if not args.export_channels else args.export_channels + assert num_channels >= export_channels, "Will be hard indeed, to export more channels than thoses dataset provide" def map_location(storage, _): return storage.cpu() + net = UNet(num_classes, num_channels=num_channels).to("cpu") chkpt = torch.load(args.checkpoint, map_location=map_location) net = torch.nn.DataParallel(net) net.load_state_dict(chkpt["state_dict"]) - # Todo: make input channels configurable, not hard-coded to three channels for RGB - batch = torch.autograd.Variable(torch.randn(1, 3, args.image_size, args.image_size)) + if export_channels < num_channels: + weights = torch.zeros((64, export_channels, 7, 7)) + weights.data = net.module.resnet.conv1.weight.data[:, :export_channels, :, :] + net.module.resnet.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + net.module.resnet.conv1.weight = nn.Parameter(weights) + + if args.type == "onnx": + batch = torch.autograd.Variable(torch.randn(1, export_channels, args.image_size, args.image_size)) + torch.onnx.export(net, batch, args.out) - torch.onnx.export(net, batch, args.model) + elif args.type == "pth": + states = {"epoch": chkpt["epoch"], "state_dict": net.state_dict(), "optimizer": chkpt["optimizer"]} + torch.save(states, args.out) diff --git a/robosat/tools/extract.py b/robosat/tools/extract.py index 036a58cc..356ae471 100644 --- a/robosat/tools/extract.py +++ b/robosat/tools/extract.py @@ -1,12 +1,10 @@ +import os +import sys import argparse -from robosat.osm.parking import ParkingHandler -from robosat.osm.building import BuildingHandler -from robosat.osm.road import RoadHandler - -# Register your osmium handlers here; in addition to the osmium handler interface -# they need to support a `save(path)` function for GeoJSON serialization to a file. 
-handlers = {"parking": ParkingHandler, "building": BuildingHandler, "road": RoadHandler} +import pkgutil +from pathlib import Path +from importlib import import_module def add_parser(subparser): @@ -16,14 +14,27 @@ def add_parser(subparser): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--type", type=str, required=True, choices=handlers.keys(), help="type of feature to extract") - parser.add_argument("map", type=str, help="path to .osm.pbf base map") + parser.add_argument("--type", type=str, required=True, help="type of feature to extract") + parser.add_argument("--path", type=str, help="path to user's extension modules dir") + parser.add_argument("pbf", type=str, help="path to .osm.pbf base map") parser.add_argument("out", type=str, help="path to GeoJSON file to store features in") parser.set_defaults(func=main) def main(args): - handler = handlers[args.type]() - handler.apply_file(filename=args.map, locations=True) + module_search_path = [args.path] if args.path else [] + module_search_path.append(os.path.join(Path(__file__).parent.parent, "osm")) + modules = [(path, name) for path, name, _ in pkgutil.iter_modules(module_search_path) if name != "core"] + if args.type not in [name for _, name in modules]: + sys.exit("Unknown type, thoses available are {}".format([name for _, name in modules])) + + if args.path: + sys.path.append(args.path) + module = import_module(args.type) + else: + module = import_module("robosat.osm.{}".format(args.type)) + + handler = getattr(module, "{}Handler".format(args.type.title()))() + handler.apply_file(filename=args.pbf, locations=True) handler.save(args.out) diff --git a/robosat/tools/features.py b/robosat/tools/features.py index 062441ac..796b59c6 100644 --- a/robosat/tools/features.py +++ b/robosat/tools/features.py @@ -1,19 +1,17 @@ +import os +import sys import argparse +from tqdm import tqdm import numpy as np - from PIL import Image -from tqdm import tqdm - -from robosat.tiles import tiles_from_slippy_map -from robosat.config import load_config -from robosat.features.parking import ParkingHandler +import pkgutil +from pathlib import Path +from importlib import import_module - -# Register post-processing handlers here; they need to support a `apply(tile, mask)` function -# for handling one mask and a `save(path)` function for GeoJSON serialization to a file. 
-handlers = {"parking": ParkingHandler} +from robosat.config import load_config +from robosat.tiles import tiles_from_slippy_map def add_parser(subparser): @@ -23,29 +21,40 @@ def add_parser(subparser): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument("--type", type=str, required=True, help="type of feature to extract") + parser.add_argument("--config", type=str, required=True, help="path to configuration file") + parser.add_argument("--path", type=str, help="path to user's extension modules dir") parser.add_argument("masks", type=str, help="slippy map directory with segmentation masks") - parser.add_argument("--type", type=str, required=True, choices=handlers.keys(), help="type of feature to extract") - parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file") parser.add_argument("out", type=str, help="path to GeoJSON file to store features in") parser.set_defaults(func=main) def main(args): - dataset = load_config(args.dataset) - labels = dataset["common"]["classes"] - assert set(labels).issuperset(set(handlers.keys())), "handlers have a class label" + module_search_path = [args.path] if args.path else [] + module_search_path.append(os.path.join(Path(__file__).parent.parent, "features")) + modules = [(path, name) for path, name, _ in pkgutil.iter_modules(module_search_path) if name != "core"] + if args.type not in [name for _, name in modules]: + sys.exit("Unknown type, thoses available are {}".format([name for _, name in modules])) + + config = load_config(args.config) + labels = config["classes"]["titles"] + if args.type not in labels: + sys.exit("The type you asked is not consistent with yours classes in the config file provided.") index = labels.index(args.type) - handler = handlers[args.type]() + if args.path: + sys.path.append(args.path) + module = import_module(args.type) + else: + module = import_module("robosat.features.{}".format(args.type)) - tiles = list(tiles_from_slippy_map(args.masks)) + handler = getattr(module, "{}Handler".format(args.type.title()))() - for tile, path in tqdm(tiles, ascii=True, unit="mask"): + for tile, path in tqdm(list(tiles_from_slippy_map(args.masks)), ascii=True, unit="mask"): image = np.array(Image.open(path).convert("P"), dtype=np.uint8) mask = (image == index).astype(np.uint8) - handler.apply(tile, mask) handler.save(args.out) diff --git a/robosat/tools/masks.py b/robosat/tools/masks.py index 310956c0..dd974c86 100644 --- a/robosat/tools/masks.py +++ b/robosat/tools/masks.py @@ -9,6 +9,7 @@ from robosat.tiles import tiles_from_slippy_map from robosat.colors import make_palette +from robosat.utils import web_ui def add_parser(subparser): @@ -18,9 +19,12 @@ def add_parser(subparser): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument("--config", type=str, required=True, help="path to configuration file") + parser.add_argument("--weights", type=float, nargs="+", help="weights for weighted average soft-voting") + parser.add_argument("--web_ui", type=str, help="web ui client base url") + parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") parser.add_argument("masks", type=str, help="slippy map directory to save masks to") parser.add_argument("probs", type=str, nargs="+", help="slippy map directories with class probabilities") - parser.add_argument("--weights", type=float, nargs="+", help="weights for weighted average soft-voting") parser.set_defaults(func=main) @@ -59,7 +63,9 @@ def load(path): mask = 
softvote(probs, axis=0, weights=args.weights) mask = mask.astype(np.uint8) - palette = make_palette("denim", "orange") + config = load_config(args.config) + palette = make_palette(config["classes"]["colors"][0], config["classes"]["colors"][1]) + out = Image.fromarray(mask, mode="P") out.putpalette(palette) @@ -68,6 +74,11 @@ def load(path): path = os.path.join(args.masks, str(z), str(x), str(y) + ".png") out.save(path, optimize=True) + if args.web_ui: + template = "leaflet.html" if not args.web_ui_template else args.web_ui_template + tiles = [tile for tile, _ in list(tiles_from_slippy_map(args.probs[0]))] + web_ui(args.masks, args.web_ui, tiles, tiles, "png", template) + def softvote(probs, axis=0, weights=None): """Weighted average soft-voting to transform class probabilities into class indices. diff --git a/robosat/tools/merge.py b/robosat/tools/merge.py index e1e79f17..eac4e7ca 100644 --- a/robosat/tools/merge.py +++ b/robosat/tools/merge.py @@ -15,8 +15,8 @@ def add_parser(subparser): "merge", help="merged adjacent GeoJSON features", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("features", type=str, help="GeoJSON file to read features from") parser.add_argument("--threshold", type=int, required=True, help="minimum distance to adjacent features, in m") + parser.add_argument("features", type=str, help="GeoJSON file to read features from") parser.add_argument("out", type=str, help="path to GeoJSON to save merged features to") parser.set_defaults(func=main) diff --git a/robosat/tools/predict.py b/robosat/tools/predict.py index 36436b52..64251b66 100644 --- a/robosat/tools/predict.py +++ b/robosat/tools/predict.py @@ -14,10 +14,12 @@ from PIL import Image from robosat.datasets import BufferedSlippyMapDirectory +from robosat.tiles import tiles_from_slippy_map from robosat.unet import UNet from robosat.config import load_config -from robosat.colors import continuous_palette_for_color -from robosat.transforms import ConvertImageMode, ImageToTensor +from robosat.colors import continuous_palette_for_color, make_palette +from robosat.transforms import ImageToTensor +from robosat.utils import web_ui def add_parser(subparser): @@ -32,29 +34,28 @@ def add_parser(subparser): parser.add_argument("--overlap", type=int, default=32, help="tile pixel overlap to predict on") parser.add_argument("--tile_size", type=int, required=True, help="tile size for slippy map tiles") parser.add_argument("--workers", type=int, default=0, help="number of workers pre-processing images") + parser.add_argument("--config", type=str, required=True, help="path to configuration file") + parser.add_argument("--masks_output", action="store_true", help="output masks rather than probs") + parser.add_argument("--web_ui", type=str, help="web ui base url") + parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") parser.add_argument("tiles", type=str, help="directory to read slippy map image tiles from") parser.add_argument("probs", type=str, help="directory to save slippy map probability masks to") - parser.add_argument("--model", type=str, required=True, help="path to model configuration file") - parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file") parser.set_defaults(func=main) def main(args): - model = load_config(args.model) - dataset = load_config(args.dataset) + config = load_config(args.config) + num_classes = len(config["classes"]["titles"]) - cuda = model["common"]["cuda"] - - device = 
torch.device("cuda" if cuda else "cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + torch.backends.cudnn.benchmark = True + else: + device = torch.device("cpu") def map_location(storage, _): - return storage.cuda() if cuda else storage.cpu() - - if cuda and not torch.cuda.is_available(): - sys.exit("Error: CUDA requested but not available") - - num_classes = len(dataset["common"]["classes"]) + return storage.cuda() if torch.cuda.is_available() else storage.cpu() # https://github.com/pytorch/pytorch/issues/7178 chkpt = torch.load(args.checkpoint, map_location=map_location) @@ -62,19 +63,21 @@ def map_location(storage, _): net = UNet(num_classes).to(device) net = nn.DataParallel(net) - if cuda: - torch.backends.cudnn.benchmark = True - net.load_state_dict(chkpt["state_dict"]) net.eval() mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] - transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)]) + transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)]) directory = BufferedSlippyMapDirectory(args.tiles, transform=transform, size=args.tile_size, overlap=args.overlap) loader = DataLoader(directory, batch_size=args.batch_size, num_workers=args.workers) + if args.masks_output: + palette = make_palette(config["classes"]["colors"][0], config["classes"]["colors"][1]) + else: + palette = continuous_palette_for_color("pink", 256) + # don't track tensors with autograd during prediction with torch.no_grad(): for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True): @@ -90,22 +93,23 @@ def map_location(storage, _): # we predicted on buffered tiles; now get back probs for original image prob = directory.unbuffer(prob) - # Quantize the floating point probabilities in [0,1] to [0,255] and store - # a single-channel `.png` file with a continuous color palette attached. 
- assert prob.shape[0] == 2, "single channel requires binary model" - assert np.allclose(np.sum(prob, axis=0), 1.), "single channel requires probabilities to sum up to one" - foreground = prob[1:, :, :] + assert np.allclose(np.sum(prob, axis=0), 1.0), "single channel requires probabilities to sum up to one" - anchors = np.linspace(0, 1, 256) - quantized = np.digitize(foreground, anchors).astype(np.uint8) + if args.masks_output: + image = np.around(prob[1:, :, :]).astype(np.uint8).squeeze() + else: + image = (prob[1:, :, :] * 255).astype(np.uint8).squeeze() - palette = continuous_palette_for_color("pink", 256) - - out = Image.fromarray(quantized.squeeze(), mode="P") + out = Image.fromarray(image, mode="P") out.putpalette(palette) os.makedirs(os.path.join(args.probs, str(z), str(x)), exist_ok=True) path = os.path.join(args.probs, str(z), str(x), str(y) + ".png") out.save(path, optimize=True) + + if args.web_ui: + template = "leaflet.html" if not args.web_ui_template else args.web_ui_template + tiles = [tile for tile, _ in tiles_from_slippy_map(args.tiles)] + web_ui(args.probs, args.web_ui, tiles, tiles, "png", template) diff --git a/robosat/tools/rasterize.py b/robosat/tools/rasterize.py index 9a1b84c5..04749816 100644 --- a/robosat/tools/rasterize.py +++ b/robosat/tools/rasterize.py @@ -14,10 +14,13 @@ from rasterio.features import rasterize from rasterio.warp import transform from supermercado import burntiles +from shapely.geometry import mapping from robosat.config import load_config -from robosat.colors import make_palette +from robosat.colors import make_palette, complementary_palette from robosat.tiles import tiles_from_csv +from robosat.utils import web_ui +from robosat.log import Log def add_parser(subparser): @@ -25,57 +28,48 @@ def add_parser(subparser): "rasterize", help="rasterize features to label masks", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument("features", type=str, help="path to GeoJSON features file") - parser.add_argument("tiles", type=str, help="path to .csv tiles file") - parser.add_argument("out", type=str, help="directory to write converted images") - parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file") + parser.add_argument("--config", type=str, required=True, help="path to configuration file") parser.add_argument("--zoom", type=int, required=True, help="zoom level of tiles") parser.add_argument("--size", type=int, default=512, help="size of rasterized image tiles in pixels") + parser.add_argument("--web_ui", type=str, help="web ui client base url") + parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") + parser.add_argument("features", type=str, nargs="+", help="path to GeoJSON features file") + parser.add_argument("cover", type=str, help="path to csv tiles cover file") + parser.add_argument("out", type=str, help="directory to write converted images") parser.set_defaults(func=main) def feature_to_mercator(feature): - """Normalize feature and converts coords to 3857. + """Convert polygon feature coords to 3857. Args: feature: geojson feature to convert to mercator geometry. 
""" # Ref: https://gist.github.com/dnomadb/5cbc116aacc352c7126e779c29ab7abe - src_crs = CRS.from_epsg(4326) - dst_crs = CRS.from_epsg(3857) - - geometry = feature["geometry"] - if geometry["type"] == "Polygon": - xys = (zip(*part) for part in geometry["coordinates"]) - xys = (list(zip(*transform(src_crs, dst_crs, *xy))) for xy in xys) + # FIXME: We assume that GeoJSON input coordinates can't be anything else than EPSG:4326 + if feature["geometry"]["type"] == "Polygon": + xys = (zip(*ring) for ring in feature["geometry"]["coordinates"]) + xys = (list(zip(*transform(CRS.from_epsg(4326), CRS.from_epsg(3857), *xy))) for xy in xys) yield {"coordinates": list(xys), "type": "Polygon"} - elif geometry["type"] == "MultiPolygon": - for component in geometry["coordinates"]: - xys = (zip(*part) for part in component) - xys = (list(zip(*transform(src_crs, dst_crs, *xy))) for xy in xys) - - yield {"coordinates": list(xys), "type": "Polygon"} - -def burn(tile, features, size): +def burn(tile, features, size, burn_value=1): """Burn tile with features. Args: tile: the mercantile tile to burn. features: the geojson features to burn. size: the size of burned image. + burn_value: the value you want in the output raster where a shape exists Returns: image: rasterized file of size with features burned. """ - # the value you want in the output raster where a shape exists - burnval = 1 - shapes = ((geometry, burnval) for feature in features for geometry in feature_to_mercator(feature)) + shapes = ((geometry, burn_value) for feature in features for geometry in feature_to_mercator(feature)) bounds = mercantile.xy_bounds(tile) transform = from_bounds(*bounds, size, size) @@ -84,40 +78,62 @@ def burn(tile, features, size): def main(args): - dataset = load_config(args.dataset) + config = load_config(args.config) - classes = dataset["common"]["classes"] - colors = dataset["common"]["colors"] + classes = config["classes"]["titles"] + colors = config["classes"]["colors"] assert len(classes) == len(colors), "classes and colors coincide" - assert len(colors) == 2, "only binary models supported right now" - bg = colors[0] - fg = colors[1] os.makedirs(args.out, exist_ok=True) # We can only rasterize all tiles at a single zoom. - assert all(tile.z == args.zoom for tile in tiles_from_csv(args.tiles)) - - with open(args.features) as f: - fc = json.load(f) + assert all(tile.z == args.zoom for tile in tiles_from_csv(args.cover)) # Find all tiles the features cover and make a map object for quick lookup. 
 feature_map = collections.defaultdict(list)
 
-    for i, feature in enumerate(tqdm(fc["features"], ascii=True, unit="feature")):
+    log = Log(os.path.join(args.out, "log"), out=sys.stderr)
 
-        if feature["geometry"]["type"] != "Polygon":
-            continue
+    def parse_polygon(feature_map, polygon, i):
 
         try:
-            for tile in burntiles.burn([feature], zoom=args.zoom):
-                feature_map[mercantile.Tile(*tile)].append(feature)
+            for j, ring in enumerate(polygon["coordinates"]):  # GeoJSON positions may be N-dimensional
+                polygon["coordinates"][j] = [[point[0], point[1]] for point in ring]  # keep x, y only
+
+            for tile in burntiles.burn([{"type": "Feature", "geometry": polygon}], zoom=args.zoom):
+                feature_map[mercantile.Tile(*tile)].append({"type": "Feature", "geometry": polygon})
+
         except ValueError as e:
-            print("Warning: invalid feature {}, skipping".format(i), file=sys.stderr)
-            continue
+            log.log("Warning: invalid feature {}, skipping".format(i))
+
+        return feature_map
+
+    def parse_geometry(feature_map, geometry, i):
+
+        if geometry["type"] == "Polygon":
+            feature_map = parse_polygon(feature_map, geometry, i)
+
+        elif geometry["type"] == "MultiPolygon":
+            for polygon in geometry["coordinates"]:
+                feature_map = parse_polygon(feature_map, {"type": "Polygon", "coordinates": polygon}, i)
+        else:
+            log.log("Notice: {} is not a surface geometry type, skipping feature {}".format(geometry["type"], i))
+
+        return feature_map
+
+    for path in args.features:
+        with open(path) as f:
+            fc = json.load(f)
+        for i, feature in enumerate(tqdm(fc["features"], ascii=True, unit="feature")):
+
+            if feature["geometry"]["type"] == "GeometryCollection":
+                for geometry in feature["geometry"]["geometries"]:
+                    feature_map = parse_geometry(feature_map, geometry, i)
+            else:
+                feature_map = parse_geometry(feature_map, feature["geometry"], i)
 
     # Burn features to tiles and write to a slippy map directory.
-    for tile in tqdm(list(tiles_from_csv(args.tiles)), ascii=True, unit="tile"):
+    for tile in tqdm(list(tiles_from_csv(args.cover)), ascii=True, unit="tile"):
         if tile in feature_map:
             out = burn(tile, feature_map[tile], args.size)
         else:
@@ -134,7 +150,13 @@ def main(args):
 
     out = Image.fromarray(out, mode="P")
 
-    palette = make_palette(bg, fg)
-    out.putpalette(palette)
+    out_path = os.path.join(args.out, str(tile.z), str(tile.x))
+    os.makedirs(out_path, exist_ok=True)
+
+    out.putpalette(complementary_palette(make_palette(colors[0], colors[1])))
+    out.save(os.path.join(out_path, "{}.png".format(tile.y)), optimize=True)
 
-    out.save(out_path, optimize=True)
+    if args.web_ui:
+        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
+        tiles = list(tiles_from_csv(args.cover))
+        web_ui(args.out, args.web_ui, tiles, tiles, "png", template)
diff --git a/robosat/tools/serve.py b/robosat/tools/serve.py
index a3e6252c..a815d280 100644
--- a/robosat/tools/serve.py
+++ b/robosat/tools/serve.py
@@ -12,6 +12,7 @@
 import mercantile
 import requests
+import cv2
 
 from PIL import Image
 from flask import Flask, send_file, render_template, abort
@@ -19,7 +20,7 @@
 from robosat.unet import UNet
 from robosat.config import load_config
 from robosat.colors import make_palette
-from robosat.transforms import ConvertImageMode, ImageToTensor
+from robosat.transforms import ImageToTensor
 
 """
 Simple tile server running a segmentation model on the fly.
@@ -62,7 +63,7 @@ def tile(z, x, y):
     if not res:
         abort(500)
 
-    image = Image.open(res)
+    # cv2.imdecode expects an imread flag; decode first, then swap BGR to RGB
+    image = cv2.imdecode(np.asarray(bytearray(res.read()), dtype=np.uint8), cv2.IMREAD_COLOR)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
     mask = predictor.segment(image)
 
@@ -83,8 +84,7 @@ def add_parser(subparser):
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument("--model", type=str, required=True, help="path to model configuration file")
-    parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file")
+    parser.add_argument("--config", type=str, required=True, help="path to configuration file")
 
     parser.add_argument("--url", type=str, help="endpoint with {z}/{x}/{y} variables to fetch image tiles from")
     parser.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load")
@@ -96,13 +96,7 @@ def add_parser(subparser):
 
 
 def main(args):
-    model = load_config(args.model)
-    dataset = load_config(args.dataset)
-
-    cuda = model["common"]["cuda"]
-
-    if cuda and not torch.cuda.is_available():
-        sys.exit("Error: CUDA requested but not available")
+    config = load_config(args.config)
 
     global size
     size = args.tile_size
@@ -120,7 +114,7 @@ def main(args):
     tiles = args.url
 
     global predictor
-    predictor = Predictor(args.checkpoint, model, dataset)
+    predictor = Predictor(args.checkpoint, config)
 
     app.run(host=args.host, port=args.port, threaded=False)
 
@@ -133,17 +127,13 @@ def send_png(image):
 
 
 class Predictor:
-    def __init__(self, checkpoint, model, dataset):
-        cuda = model["common"]["cuda"]
-
-        assert torch.cuda.is_available() or not cuda, "cuda is available when requested"
+    def __init__(self, checkpoint, config):
 
-        self.cuda = cuda
-        self.device = torch.device("cuda" if cuda else "cpu")
+        self.cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.cuda else "cpu")
 
         self.checkpoint = checkpoint
-        self.model = model
-        self.dataset = dataset
+        self.config = config
 
         self.net = self.net_from_chkpt_()
 
@@ -152,7 +142,7 @@ def segment(self, image):
         with torch.no_grad():
             mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
 
-            transform = Compose([ConvertImageMode(mode="RGB"), ImageToTensor(), Normalize(mean=mean, std=std)])
+            transform = Compose([ImageToTensor(), Normalize(mean=mean, std=std)])
             image = transform(image)
 
             batch = image.unsqueeze(0).to(self.device)
@@ -166,7 +156,7 @@ def segment(self, image):
 
         mask = Image.fromarray(mask, mode="P")
 
-        palette = make_palette(*self.dataset["common"]["colors"])
+        palette = make_palette(*self.config["classes"]["colors"])
         mask.putpalette(palette)
 
         return mask
@@ -178,7 +168,7 @@ def map_location(storage, _):
 
         # https://github.com/pytorch/pytorch/issues/7178
         chkpt = torch.load(self.checkpoint, map_location=map_location)
 
-        num_classes = len(self.dataset["common"]["classes"])
+        num_classes = len(self.config["classes"]["titles"])
 
         net = UNet(num_classes).to(self.device)
         net = nn.DataParallel(net)
diff --git a/robosat/tools/subset.py b/robosat/tools/subset.py
index 14f0a70a..d47858e7 100644
--- a/robosat/tools/subset.py
+++ b/robosat/tools/subset.py
@@ -1,10 +1,14 @@
 import os
+import sys
 import argparse
 import shutil
+from glob import glob
 
 from tqdm import tqdm
 
-from robosat.tiles import tiles_from_slippy_map, tiles_from_csv
+from robosat.tiles import tiles_from_csv
+from robosat.utils import web_ui
+from robosat.log import Log
 
 
 def add_parser(subparser):
@@ -13,26 +17,43 @@ def add_parser(subparser):
         help="filter images in a slippy map directory using a csv",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
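+    # Illustrative invocation (paths are hypothetical):
+    #   ./rs subset --move /data/images /data/cover.csv /data/subset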
parser.add_argument("images", type=str, help="directory to read slippy map image tiles from for filtering") - parser.add_argument("tiles", type=str, help="csv to filter images by") - parser.add_argument("out", type=str, help="directory to save filtered images to") + parser.add_argument("dir", type=str, help="directory to read slippy map tiles from for filtering") + parser.add_argument("cover", type=str, help="csv cover to filter tiles by") + parser.add_argument("out", type=str, help="directory to save filtered tiles to") + parser.add_argument("--move", action="store_true", help="move files from src to dst (rather than copy them)") + parser.add_argument("--web_ui", type=str, help="web ui base url") + parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template") parser.set_defaults(func=main) def main(args): - images = tiles_from_slippy_map(args.images) + log = Log(os.path.join(args.out, "log"), out=sys.stderr) - tiles = set(tiles_from_csv(args.tiles)) + tiles = set(tiles_from_csv(args.cover)) + extension = "" - for tile, src in tqdm(list(images), desc="Subset", unit="image", ascii=True): - if tile not in tiles: - continue - - # The extention also includes the period. - extention = os.path.splitext(src)[1] + for tile in tqdm(tiles, desc="Subset", unit="tiles", ascii=True): - os.makedirs(os.path.join(args.out, str(tile.z), str(tile.x)), exist_ok=True) - dst = os.path.join(args.out, str(tile.z), str(tile.x), "{}{}".format(tile.y, extention)) - - shutil.copyfile(src, dst) + paths = glob(os.path.join(args.dir, str(tile.z), str(tile.x), "{}.*".format(tile.y))) + if len(paths) != 1: + log.log("Warning: {} skipped.".format(tile)) + continue + src = paths[0] + + try: + extension = os.path.splitext(src)[1][1:] + dst = os.path.join(args.out, str(tile.z), str(tile.x), "{}.{}".format(tile.y, extension)) + if not os.path.isdir(os.path.join(args.out, str(tile.z), str(tile.x))): + os.makedirs(os.path.join(args.out, str(tile.z), str(tile.x)), exist_ok=True) + if args.move: + assert os.path.isfile(src) + shutil.move(src, dst) + else: + shutil.copyfile(src, dst) + except: + sys.exit("Error: Unable to process {}".format(tile)) + + if args.web_ui: + template = "leaflet.html" if not args.web_ui_template else args.web_ui_template + web_ui(args.out, args.web_ui, tiles, tiles, extension, template) diff --git a/robosat/tools/templates/compare.html b/robosat/tools/templates/compare.html new file mode 100644 index 00000000..839054bd --- /dev/null +++ b/robosat/tools/templates/compare.html @@ -0,0 +1,57 @@ + + + + RoboSat Compare WebUI + + + + +
+    Keyboard shortcuts:
+      Right Arrow: next image to compare, if any.
+      Left Arrow: previous image to compare, if any.
+      SpaceBar: select, or unselect, the current image.
+      Esc: ask to copy selected images, as a text cover, to clipboard.
diff --git a/robosat/tools/templates/leaflet.html b/robosat/tools/templates/leaflet.html
new file mode 100644
index 00000000..2fe49cac
--- /dev/null
+++ b/robosat/tools/templates/leaflet.html
@@ -0,0 +1,30 @@
+    RoboSat Leaflet WebUI
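+    <!-- Placeholders {{base_url}}, {{ext}}, {{tiles}}, {{zoom}} and {{center}} are
+         substituted by robosat.utils.web_ui before index.html is written. -->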
diff --git a/robosat/tools/tile.py b/robosat/tools/tile.py
new file mode 100644
index 00000000..10a29c7b
--- /dev/null
+++ b/robosat/tools/tile.py
@@ -0,0 +1,122 @@
+import os
+import sys
+import math
+import argparse
+from tqdm import tqdm
+
+import numpy as np
+from PIL import Image
+
+import mercantile
+
+from rasterio import open as rasterio_open
+from rasterio.vrt import WarpedVRT
+from rasterio.enums import Resampling
+from rasterio.warp import transform_bounds, calculate_default_transform
+from rasterio.transform import from_bounds
+
+from robosat.config import load_config
+from robosat.colors import make_palette
+from robosat.utils import web_ui
+
+
+def add_parser(subparser):
+    parser = subparser.add_parser(
+        "tile", help="tile a raster image or label", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument("--size", type=int, default=512, help="size of tiles side in pixels")
+    parser.add_argument("--zoom", type=int, required=True, help="zoom level of tiles")
+    parser.add_argument("--type", type=str, choices=["image", "label"], default="image", help="image or label tiling")
+    parser.add_argument("--config", type=str, help="path to configuration file, mandatory for label tiling")
+    parser.add_argument("--no_data", type=int, help="color considered as no data [0-255]; related tiles are skipped")
+    parser.add_argument("--web_ui", type=str, help="web ui base url")
+    parser.add_argument("--web_ui_template", type=str, help="path to an alternate web ui template")
+    parser.add_argument("raster", type=str, help="path to the raster to tile")
+    parser.add_argument("out", type=str, help="directory to write tiles")
+
+    parser.set_defaults(func=main)
+
+
+def main(args):
+
+    if args.type == "label":
+        try:
+            config = load_config(args.config)
+        except Exception:
+            sys.exit("Error: Unable to load DataSet config file")
+
+        classes = config["classes"]["titles"]
+        colors = config["classes"]["colors"]
+        assert len(classes) == len(colors), "classes and colors coincide"
+        assert len(colors) == 2, "only binary models supported right now"
+
+    try:
+        raster = rasterio_open(args.raster)
+        w, s, e, n = bounds = transform_bounds(raster.crs, "EPSG:4326", *raster.bounds)
+        transform, _, _ = calculate_default_transform(raster.crs, "EPSG:3857", raster.width, raster.height, *bounds)
+    except Exception:
+        sys.exit("Error: Unable to load raster or deal with its projection")
+
+    tiles = [mercantile.Tile(x=x, y=y, z=z) for x, y, z in mercantile.tiles(w, s, e, n, args.zoom)]
+    tiles_nodata = []
+
+    for tile in tqdm(tiles, desc="Tiling", unit="tile", ascii=True):
+
+        w, s, e, n = tile_bounds = mercantile.xy_bounds(tile)
+
+        # Inspired by Rio-Tiler, cf: https://github.com/mapbox/rio-tiler/pull/45
+        warp_vrt = WarpedVRT(
+            raster,
+            crs="EPSG:3857",
+            resampling=Resampling.bilinear,
+            add_alpha=False,
+            transform=from_bounds(*tile_bounds, args.size, args.size),
+            width=math.ceil((e - w) / transform.a),
+            height=math.ceil((s - n) / transform.e),
+        )
+        data = warp_vrt.read(out_shape=(len(raster.indexes), args.size, args.size), window=warp_vrt.window(w, s, e, n))
+
+        # If no_data is set, skip tiles with at least one whole border filled only with no_data (on all bands)
+        if args.no_data is not None and (
+            np.all(data[:, 0, :] == args.no_data)
+            or np.all(data[:, -1, :] == args.no_data)
+            or np.all(data[:, :, 0] == args.no_data)
+            or np.all(data[:, :, -1] == args.no_data)
+        ):
+            tiles_nodata.append(tile)
+            continue
+
+        C, H, W = data.shape
+
+        os.makedirs(os.path.join(args.out, str(args.zoom), str(tile.x)), exist_ok=True)
+        path = os.path.join(args.out, str(args.zoom), str(tile.x), str(tile.y))
+
+        if args.type == "label":
+            assert C == 1, "Error: Label raster input should be 1 band"
+
+            ext = "png"
+            img = Image.fromarray(np.squeeze(data, axis=0), mode="P")
+            img.putpalette(make_palette(colors[0], colors[1]))
+            img.save("{}.{}".format(path, ext), optimize=True)
+
+        elif args.type == "image":
+            assert C == 1 or C == 3, "Error: Image raster input should be either 1 or 3 bands"
+
+            # GeoTIFF inputs may be 16 or 32 bits; scale down to the 8 bits PNG and WebP can store
+            if data.dtype == "uint16":
+                data = np.uint8(data / 256)
+            elif data.dtype == "uint32":
+                data = np.uint8(data / (256 * 256))
+
+            if C == 1:
+                ext = "png"
+                Image.fromarray(np.squeeze(data, axis=0), mode="L").save("{}.{}".format(path, ext), optimize=True)
+            elif C == 3:
+                ext = "webp"
+                Image.fromarray(np.moveaxis(data, 0, 2), mode="RGB").save("{}.{}".format(path, ext), optimize=True)
+
+    if args.web_ui:
+        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
+        tiles = [tile for tile in tiles if tile not in tiles_nodata]
+        web_ui(args.out, args.web_ui, tiles, tiles, ext, template)
diff --git a/robosat/tools/train.py b/robosat/tools/train.py
index 85d6baa9..954a7aeb 100644
--- a/robosat/tools/train.py
+++ b/robosat/tools/train.py
@@ -11,16 +11,15 @@
 from torch.nn import DataParallel
 from torch.optim import Adam
 from torch.utils.data import DataLoader
-from torchvision.transforms import Resize, CenterCrop, Normalize
+from torchvision.transforms import Normalize
 
 from tqdm import tqdm
 
 from robosat.transforms import (
     JointCompose,
     JointTransform,
-    JointRandomHorizontalFlip,
-    JointRandomRotation,
-    ConvertImageMode,
+    JointResize,
+    JointRandomFlipOrRotate,
     ImageToTensor,
     MaskToTensor,
 )
@@ -44,87 +43,106 @@ def add_parser(subparser):
         "train", help="trains model on dataset", formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument("--model", type=str, required=True, help="path to model configuration file")
-    parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file")
+    parser.add_argument("--config", type=str, required=True, help="path to configuration file")
    parser.add_argument("--checkpoint", type=str, required=False, help="path to a model checkpoint (to retrain)")
-    parser.add_argument("--resume", type=bool, default=False, help="resume training or fine-tuning (if checkpoint)")
+    parser.add_argument("--resume", action="store_true", help="resume training (requires a checkpoint)")
     parser.add_argument("--workers", type=int, default=0, help="number of workers pre-processing images")
+    parser.add_argument("--dataset", type=str, help="if set, override dataset path value from config file")
+    parser.add_argument("--epochs", type=int, help="if set, override epochs value from config file")
+    parser.add_argument("--lr", type=float, help="if set, override learning rate value from config file")
+    parser.add_argument("out", type=str, help="directory to save checkpoint .pth files and log")
 
     parser.set_defaults(func=main)
 
 
 def main(args):
-    model = load_config(args.model)
-    dataset = load_config(args.dataset)
+    config = load_config(args.config)
+    lr = args.lr if args.lr else config["model"]["lr"]
+    dataset_path = args.dataset if args.dataset else config["dataset"]["path"]
+    num_epochs = args.epochs if args.epochs else config["model"]["epochs"]
 
-    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")
+    log = Log(os.path.join(args.out, "log"))
 
-    if model["common"]["cuda"] and not torch.cuda.is_available():
-        sys.exit("Error: CUDA requested but not available")
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
 
-    os.makedirs(model["common"]["checkpoint"], exist_ok=True)
-
-    num_classes = len(dataset["common"]["classes"])
-    net = UNet(num_classes)
-    net = DataParallel(net)
-    net = net.to(device)
-
-    if model["common"]["cuda"]:
         torch.backends.cudnn.benchmark = True
-
-    try:
-        weight = torch.Tensor(dataset["weights"]["values"])
-    except KeyError:
-        if model["opt"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
+        log.log("RoboSat - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
+    else:
+        device = torch.device("cpu")
+        log.log("RoboSat - training on CPU, with {} workers".format(args.workers))
+
+    num_classes = len(config["classes"]["titles"])
+    num_channels = 0
+    for channel in config["channels"]:
+        num_channels += len(channel["bands"])
+    pretrained = config["model"]["pretrained"]
+    net = DataParallel(UNet(num_classes, num_channels=num_channels, pretrained=pretrained)).to(device)
+
+    if config["model"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
+        try:
+            weight = torch.Tensor(config["dataset"]["weights"])
+        except KeyError:
             sys.exit("Error: The loss function used needs dataset weights values")
 
-    optimizer = Adam(net.parameters(), lr=model["opt"]["lr"], weight_decay=model["opt"]["decay"])
+    optimizer = Adam(net.parameters(), lr=lr, weight_decay=config["model"]["decay"])
 
     resume = 0
     if args.checkpoint:
 
         def map_location(storage, _):
-            return storage.cuda() if model["common"]["cuda"] else storage.cpu()
+            return storage.cuda() if torch.cuda.is_available() else storage.cpu()
 
         # https://github.com/pytorch/pytorch/issues/7178
         chkpt = torch.load(args.checkpoint, map_location=map_location)
         net.load_state_dict(chkpt["state_dict"])
+        log.log("Using checkpoint: {}".format(args.checkpoint))
 
         if args.resume:
             optimizer.load_state_dict(chkpt["optimizer"])
             resume = chkpt["epoch"]
 
-    if model["opt"]["loss"] == "CrossEntropy":
+    if config["model"]["loss"] == "CrossEntropy":
         criterion = CrossEntropyLoss2d(weight=weight).to(device)
-    elif model["opt"]["loss"] == "mIoU":
+    elif config["model"]["loss"] == "mIoU":
         criterion = mIoULoss2d(weight=weight).to(device)
-    elif model["opt"]["loss"] == "Focal":
+    elif config["model"]["loss"] == "Focal":
         criterion = FocalLoss2d(weight=weight).to(device)
-    elif model["opt"]["loss"] == "Lovasz":
+    elif config["model"]["loss"] == "Lovasz":
         criterion = LovaszLoss2d().to(device)
     else:
-        sys.exit("Error: Unknown [opt][loss] value !")
+        sys.exit("Error: Unknown [model][loss] value!")
 
-    train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers)
+    train_loader, val_loader = get_dataset_loaders(dataset_path, config, args.workers)
 
-    num_epochs = model["opt"]["epochs"]
     if resume >= num_epochs:
         sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.config))
 
     history = collections.defaultdict(list)
-    log = Log(os.path.join(model["common"]["checkpoint"], "log"))
-
-    log.log("--- Hyper Parameters on Dataset: {} ---".format(dataset["common"]["dataset"]))
-    log.log("Batch Size:\t {}".format(model["common"]["batch_size"]))
-    log.log("Image Size:\t {}".format(model["common"]["image_size"]))
-    log.log("Learning Rate:\t {}".format(model["opt"]["lr"]))
-    log.log("Weight Decay:\t {}".format(model["opt"]["decay"]))
-    log.log("Loss function:\t {}".format(model["opt"]["loss"]))
{}".format(model["opt"]["loss"])) + + log.log("") + log.log("--- Input tensor from Dataset: {} ---".format(dataset_path)) + num_channel = 1 + for channel in config["channels"]: + for band in channel["bands"]: + log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["sub"], band)) + num_channel += 1 + log.log("") + log.log("--- Hyper Parameters ---") + log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"])) + log.log("Image Size:\t\t {}".format(config["model"]["image_size"])) + log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"])) + log.log("Learning Rate:\t\t {}".format(lr)) + log.log("Weight Decay:\t\t {}".format(config["model"]["decay"])) + log.log("Loss function:\t\t {}".format(config["model"]["loss"])) + log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"])) if "weight" in locals(): - log.log("Weights :\t {}".format(dataset["weights"]["values"])) - log.log("---") + log.log("Weights :\t\t {}".format(config["dataset"]["weights"])) + log.log("") for epoch in range(resume, num_epochs): + + log.log("---") log.log("Epoch: {}/{}".format(epoch + 1, num_epochs)) train_hist = train(train_loader, num_classes, device, net, optimizer, criterion) @@ -132,7 +150,7 @@ def map_location(storage, _): "Train loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format( train_hist["loss"], train_hist["miou"], - dataset["common"]["classes"][1], + config["classes"]["titles"][1], train_hist["fg_iou"], train_hist["mcc"], ) @@ -144,21 +162,18 @@ def map_location(storage, _): val_hist = validate(val_loader, num_classes, device, net, criterion) log.log( "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format( - val_hist["loss"], val_hist["miou"], dataset["common"]["classes"][1], val_hist["fg_iou"], val_hist["mcc"] + val_hist["loss"], val_hist["miou"], config["classes"]["titles"][1], val_hist["fg_iou"], val_hist["mcc"] ) ) for k, v in val_hist.items(): history["val " + k].append(v) - - visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs) - plot(os.path.join(model["common"]["checkpoint"], visual), history) - - checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs) + visual_path = os.path.join(args.out, "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)) + plot(visual_path, history) states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()} - - torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint)) + checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs)) + torch.save(states, checkpoint_path) def train(loader, num_classes, device, net, optimizer, criterion): @@ -243,35 +258,35 @@ def validate(loader, num_classes, device, net, criterion): } -def get_dataset_loaders(model, dataset, workers): - target_size = (model["common"]["image_size"],) * 2 - batch_size = model["common"]["batch_size"] - path = dataset["common"]["dataset"] +def get_dataset_loaders(path, config, workers): + # Values computed on ImageNet DataSet mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] transform = JointCompose( [ - JointTransform(ConvertImageMode("RGB"), ConvertImageMode("P")), - JointTransform(Resize(target_size, Image.BILINEAR), Resize(target_size, Image.NEAREST)), - JointTransform(CenterCrop(target_size), CenterCrop(target_size)), - JointRandomHorizontalFlip(0.5), - JointRandomRotation(0.5, 90), - JointRandomRotation(0.5, 90), - JointRandomRotation(0.5, 90), + 
+            JointResize(config["model"]["image_size"]),
+            JointRandomFlipOrRotate(config["model"]["data_augmentation"]),
             JointTransform(ImageToTensor(), MaskToTensor()),
             JointTransform(Normalize(mean=mean, std=std), None),
         ]
     )
 
     train_dataset = SlippyMapTilesConcatenation(
-        [os.path.join(path, "training", "images")], os.path.join(path, "training", "labels"), transform
+        os.path.join(path, "training"),
+        config["channels"],
+        os.path.join(path, "training", "labels"),
+        joint_transform=transform,
     )
 
     val_dataset = SlippyMapTilesConcatenation(
-        [os.path.join(path, "validation", "images")], os.path.join(path, "validation", "labels"), transform
+        os.path.join(path, "validation"),
+        config["channels"],
+        os.path.join(path, "validation", "labels"),
+        joint_transform=transform,
    )
 
+    batch_size = config["model"]["batch_size"]
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=workers)
     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=workers)
diff --git a/robosat/tools/weights.py b/robosat/tools/weights.py
index c154fffd..af68dd1e 100644
--- a/robosat/tools/weights.py
+++ b/robosat/tools/weights.py
@@ -10,7 +10,7 @@
 from robosat.config import load_config
 from robosat.datasets import SlippyMapTiles
-from robosat.transforms import ConvertImageMode, MaskToTensor
+from robosat.transforms import MaskToTensor
 
 
 def add_parser(subparser):
@@ -18,20 +18,18 @@ def add_parser(subparser):
         "weights", help="computes class weights on dataset", formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument("--dataset", type=str, required=True, help="path to dataset configuration file")
+    parser.add_argument("--config", type=str, required=True, help="path to configuration file")
 
     parser.set_defaults(func=main)
 
 
 def main(args):
-    dataset = load_config(args.dataset)
+    config = load_config(args.config)
+    path = config["dataset"]["path"]
+    num_classes = len(config["classes"]["titles"])
 
-    path = dataset["common"]["dataset"]
-    num_classes = len(dataset["common"]["classes"])
-
-    train_transform = Compose([ConvertImageMode(mode="P"), MaskToTensor()])
-
-    train_dataset = SlippyMapTiles(os.path.join(path, "training", "labels"), transform=train_transform)
+    train_transform = Compose([MaskToTensor()])
+    train_dataset = SlippyMapTiles(os.path.join(path, "training", "labels"), "mask", transform=train_transform)
 
     n = 0
     counts = np.zeros(num_classes, dtype=np.int64)
diff --git a/robosat/transforms.py b/robosat/transforms.py
index 9347fa34..96fa27b2 100644
--- a/robosat/transforms.py
+++ b/robosat/transforms.py
@@ -2,56 +2,43 @@
 """
 
 import random
-
 import torch
+import cv2
 import numpy as np
-from PIL import Image
-
-import torchvision
-
-# Callable to convert a RGB image into a PyTorch tensor.
-ImageToTensor = torchvision.transforms.ToTensor
-
 
-class MaskToTensor:
-    """Callable to convert a PIL image into a PyTorch tensor.
+class ImageToTensor:
+    """Callable to convert a NumPy H,W,C image into a PyTorch C,H,W tensor.
     """
 
     def __call__(self, image):
         """Converts the image into a tensor.
 
         Args:
-            image: the PIL image to convert into a PyTorch tensor.
+            image: the image to convert into a PyTorch tensor.
 
         Returns:
             The converted PyTorch tensor.
         """
 
-        return torch.from_numpy(np.array(image, dtype=np.uint8)).long()
+        return torch.from_numpy(np.moveaxis(image, 2, 0)).float()
 
 
-class ConvertImageMode:
-    """Callable to convert a PIL image into a specific image mode (e.g. RGB, P)
+class MaskToTensor:
+    """Callable to convert a NumPy H,W image into a PyTorch tensor.
     """
 
-    def __init__(self, mode):
-        """Creates an `ConvertImageMode` instance.
+    def __call__(self, mask):
+        """Converts the mask into a tensor.
 
         Args:
-            mode: the PIL image mode string
-        """
-
-        self.mode = mode
+            mask: the mask to convert into a PyTorch tensor.
 
-    def __call__(self, image):
-        """Applies to mode conversion to an image.
-
-        Args:
-            image: the PIL.Image image to transform.
+        Returns:
+            The converted PyTorch tensor.
         """
 
-        return image.convert(self.mode)
+        return torch.from_numpy(mask).long()
 
 
 class JointCompose:
@@ -67,21 +54,21 @@ def __init__(self, transforms):
 
         self.transforms = transforms
 
-    def __call__(self, images, mask):
-        """Applies multiple transformations to the images and the mask at the same time.
+    def __call__(self, image, mask):
+        """Applies multiple transformations to the image and its mask at the same time.
 
         Args:
-            images: the PIL.Image images to transform.
-            mask: the PIL.Image mask to transform.
+            image: the image to transform.
+            mask: the mask to transform.
 
         Returns:
-            The transformed PIL.Image (images, mask) tuple.
+            The transformed (image, mask) tuple.
         """
 
         for transform in self.transforms:
-            images, mask = transform(images, mask)
+            image, mask = transform(image, mask)
 
-        return images, mask
+        return image, mask
 
 
 class JointTransform:
@@ -94,128 +81,108 @@ def __init__(self, image_transform, mask_transform):
         """Creates an `JointTransform` instance.
 
         Args:
-            image_transform: the transformation to run on the images or `None` for no-op.
+            image_transform: the transformation to run on the image or `None` for no-op.
             mask_transform: the transformation to run on the mask or `None` for no-op.
 
         Returns:
-            The (images, mask) tuple with the transformations applied.
+            The (image, mask) tuple with the transformations applied.
         """
 
         self.image_transform = image_transform
         self.mask_transform = mask_transform
 
-    def __call__(self, images, mask):
-        """Applies the transformations associated with images and their mask.
+    def __call__(self, image, mask):
+        """Applies the transformations associated with image and its mask.
 
         Args:
-            images: the PIL.Image images to transform.
-            mask: the PIL.Image mask to transform.
+            image: the image to transform.
+            mask: the mask to transform.
 
         Returns:
-            The PIL.Image (images, mask) tuple with images and mask transformed.
+            The (image, mask) tuple with the transformations applied.
         """
 
         if self.image_transform is not None:
-            images = [self.image_transform(v) for v in images]
+            image = self.image_transform(image)
 
         if self.mask_transform is not None:
             mask = self.mask_transform(mask)
 
-        return images, mask
+        return image, mask
 
 
-class JointRandomVerticalFlip:
-    """Callable to randomly flip images and its mask top to bottom.
+class JointRandomFlipOrRotate:
+    """Callable to randomly flip or rotate an image and its mask.
     """
 
     def __init__(self, p):
-        """Creates an `JointRandomVerticalFlip` instance.
+        """Creates a `JointRandomFlipOrRotate` instance.
 
         Args:
-            p: the probability for flipping.
+            p: the probability of applying a flip or rotation.
         """
-
+        assert p >= 0.0 and p <= 1.0, "Probability must be expressed in 0-1 interval"
         self.p = p
 
-    def __call__(self, images, mask):
-        """Randomly flips images and their mask top to bottom.
+    def __call__(self, image, mask):
+        """Randomly flips or rotates the image and its mask.
 
         Args:
-            images: the PIL.Image image to transform.
-            mask: the PIL.Image mask to transform.
+            image: the image to transform.
+            mask: the mask to transform.
         Returns:
-            The PIL.Image (images, mask) tuple with either images and mask flipped or none of them flipped.
-        """
-
-        if random.random() < self.p:
-            return [v.transpose(Image.FLIP_TOP_BOTTOM) for v in images], mask.transpose(Image.FLIP_TOP_BOTTOM)
-        else:
-            return images, mask
-
-
-class JointRandomHorizontalFlip:
-    """Callable to randomly flip images and their mask left to right.
-    """
-
-    def __init__(self, p):
-        """Creates an `JointRandomHorizontalFlip` instance.
-
-        Args:
-            p: the probability for flipping.
+            The (image, mask) tuple, with image and mask either flipped or rotated
+            the same way, or both left untouched (but always kept in sync).
         """
 
-        self.p = p
+        if random.random() > self.p:
+            return image, mask
 
-    def __call__(self, images, mask):
-        """Randomly flips image and their mask left to right.
+        transform = random.choice(["Rotate90", "Rotate180", "Rotate270", "HorizontalFlip", "VerticalFlip"])
 
-        Args:
-            images: the PIL.Image images to transform.
-            mask: the PIL.Image mask to transform.
+        # transpose followed by a flip implements 90 degree rotation steps
+        if transform == "Rotate90":
+            return cv2.flip(cv2.transpose(image), +1), cv2.flip(cv2.transpose(mask), +1)
+        elif transform == "Rotate180":
+            return cv2.flip(image, -1), cv2.flip(mask, -1)
+        elif transform == "Rotate270":
+            return cv2.flip(cv2.transpose(image), 0), cv2.flip(cv2.transpose(mask), 0)
+        elif transform == "HorizontalFlip":
+            return cv2.flip(image, +1), cv2.flip(mask, +1)
+        elif transform == "VerticalFlip":
+            return cv2.flip(image, 0), cv2.flip(mask, 0)
 
-        Returns:
-            The PIL.Image (images, mask) tuple with either images and mask flipped or none of them flipped.
-        """
-
-        if random.random() < self.p:
-            return [v.transpose(Image.FLIP_LEFT_RIGHT) for v in images], mask.transpose(Image.FLIP_LEFT_RIGHT)
-        else:
-            return images, mask
-
 
-class JointRandomRotation:
-    """Callable to randomly rotate images and their mask.
+class JointResize:
+    """Callable to resize an image and its mask.
     """
 
-    def __init__(self, p, degree):
-        """Creates an `JointRandomRotation` instance.
+    def __init__(self, size):
+        """Creates a `JointResize` instance.
 
         Args:
-            p: the probability for rotating.
+            size: the desired square side size, in pixels.
         """
+        self.hw = (size, size)
 
-        self.p = p
-
-        methods = {90: Image.ROTATE_90, 180: Image.ROTATE_180, 270: Image.ROTATE_270}
-
-        if degree not in methods.keys():
-            raise NotImplementedError("We only support multiple of 90 degree rotations for now")
-
-        self.method = methods[degree]
-
-    def __call__(self, images, mask):
-        """Randomly rotates images and their mask.
+    def __call__(self, image, mask):
+        """Resizes the image and its mask.
 
         Args:
-            images: the PIL.Image image to transform.
-            mask: the PIL.Image mask to transform.
+            image: the image to transform.
+            mask: the mask to transform.
 
         Returns:
-            The PIL.Image (images, mask) tuple with either images and mask rotated or none of them rotated.
+            The resized (image, mask) tuple.
         """
 
-        if random.random() < self.p:
-            return [v.transpose(self.method) for v in images], mask.transpose(self.method)
+        if self.hw == image.shape[0:2]:
+            pass
+        elif self.hw[0] < image.shape[0] and self.hw[1] < image.shape[1]:
+            # INTER_AREA gives the best results when shrinking
+            image = cv2.resize(image, self.hw, interpolation=cv2.INTER_AREA)
         else:
-            return images, mask
+            image = cv2.resize(image, self.hw, interpolation=cv2.INTER_LINEAR)
+
+        if self.hw != mask.shape:
+            # masks are categorical: nearest neighbor avoids inventing new classes
+            mask = cv2.resize(mask, self.hw, interpolation=cv2.INTER_NEAREST)
+
+        return image, mask
diff --git a/robosat/unet.py b/robosat/unet.py
index bedabea1..4f9fd799 100644
--- a/robosat/unet.py
+++ b/robosat/unet.py
@@ -70,7 +70,7 @@ def forward(self, x):
             The network's output tensor.
""" - return self.block(nn.functional.upsample(x, scale_factor=2, mode="nearest")) + return self.block(nn.functional.interpolate(x, scale_factor=2, mode="nearest")) class UNet(nn.Module): @@ -79,22 +79,30 @@ class UNet(nn.Module): Also known as AlbuNet due to its inventor Alexander Buslaev. """ - def __init__(self, num_classes, num_filters=32, pretrained=True): + def __init__(self, num_classes, num_channels=3, num_filters=32, pretrained=True): """Creates an `UNet` instance for semantic segmentation. Args: num_classes: number of classes to predict. - pretrained: use ImageNet pre-trained backbone feature extractor + num_channels: number of inputs channels (e.g bands) + pretrained: use ImageNet pre-trained ResNet Encoder weights """ super().__init__() - # Todo: make input channels configurable, not hard-coded to three channels for RGB - self.resnet = resnet50(pretrained=pretrained) - # Access resnet directly in forward pass; do not store refs here due to - # https://github.com/pytorch/pytorch/issues/8392 + assert num_channels + + if num_channels != 3: + weights = nn.init.xavier_uniform_(torch.zeros((64, num_channels, 7, 7))) + if pretrained: + for c in range(min(num_channels, 3)): + weights.data[:, c, :, :] = self.resnet.conv1.weight.data[:, c, :, :] + self.resnet.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.resnet.conv1.weight = nn.Parameter(weights) + + # No encoder reference, give a look at https://github.com/pytorch/pytorch/issues/8392 self.center = DecoderBlock(2048, num_filters * 8) diff --git a/robosat/utils.py b/robosat/utils.py index 60028221..34c64763 100644 --- a/robosat/utils.py +++ b/robosat/utils.py @@ -1,4 +1,11 @@ +import re +import os +import sys +import json import matplotlib +from pathlib import Path +from mercantile import feature +from robosat.tiles import pixel_to_location matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 @@ -20,3 +27,39 @@ def plot(out, history): plt.savefig(out, format="png") plt.close() + + +def web_ui(out, base_url, coverage_tiles, selected_tiles, ext, template): + + try: + if os.path.isfile(template): + web_ui = open(template, "r").read() + else: + web_ui = open(os.path.join(Path(__file__).parent, "tools", "templates", template), "r").read() + except: + sys.exit("Unable to open Web UI template {}".format(template)) + + web_ui = re.sub("{{base_url}}", base_url, web_ui) + web_ui = re.sub("{{ext}}", ext, web_ui) + web_ui = re.sub("{{tiles}}", "tiles.json" if selected_tiles else "''", web_ui) + + if coverage_tiles: + # Could surely be improve, but for now, took the first tile to center on + tile = list(coverage_tiles)[0] + x, y, z = map(int, [tile.x, tile.y, tile.z]) + web_ui = re.sub("{{zoom}}", str(z), web_ui) + web_ui = re.sub("{{center}}", str(list(pixel_to_location(tile, 0.5, 0.5))[::-1]), web_ui) + + with open(os.path.join(out, "index.html"), "w", encoding="utf-8") as fp: + fp.write(web_ui) + + if selected_tiles: + with open(os.path.join(out, "tiles.json"), "w", encoding="utf-8") as fp: + fp.write('{"type":"FeatureCollection","features":[') + first = True + for tile in selected_tiles: + prop = '"properties":{{"x":{},"y":{},"z":{}}}'.format(int(tile.x), int(tile.y), int(tile.z)) + geom = '"geometry":{}'.format(json.dumps(feature(tile, precision=6)["geometry"])) + fp.write('{}{{"type":"Feature",{},{}}}'.format("," if not first else "", geom, prop)) + first = False + fp.write("]}") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 97e6aac6..50b4f629 100644 --- 
+++ b/tests/test_datasets.py
@@ -12,17 +12,16 @@ class TestSlippyMapTiles(unittest.TestCase):
     images = "tests/fixtures/images/"
 
     def test_len(self):
-        dataset = SlippyMapTiles(TestSlippyMapTiles.images)
+        dataset = SlippyMapTiles(TestSlippyMapTiles.images, "image")
         self.assertEqual(len(dataset), 3)
 
     def test_getitem(self):
-        dataset = SlippyMapTiles(TestSlippyMapTiles.images)
+        dataset = SlippyMapTiles(TestSlippyMapTiles.images, "image")
 
         image, tile = dataset[0]
         assert tile == mercantile.Tile(69105, 105093, 18)
 
-        # Inspired by: https://github.com/python-pillow/Pillow/blob/master/Tests/test_image.py#L37-L38
-        self.assertEqual(repr(image)[:45], "