High quality remote sensing data are becoming increasingly available across the globe. Making sense of this huge volume of data is a challenge that many companies are tackling from many different angles. One piece of amazing open source technology that helps with this problem is TorchGeo, which aims to fill an important part of the gap between deep learning and remote sensing.
If you are not already familiar with it, TorchGeo is built on top of PyTorch with the specific goal of making data loading and transformation of remote sensing data easy, handling everything from geographic projections to indexing and sampling raster files for training and validation. In this post we’re going to dive into using TorchGeo to automatically generate segmentation masks for training from a spatial dataset.
This post is the first part of a multi-part series exploring various aspects of an instance segmentation model we built for recognizing solar panels from remote sensing imagery.
The Trouble With Remote Sensing Data for Computer Vision
In a more typical computer vision problem, samples are naturally discretized by having a single image per sample, and (at most) a few labels per sample. This makes sense when you have a dataset made up of many small images containing the objects you are looking to classify/detect/segment.
Remote sensing imagery on the other hand has some significant differences which can challenge typical computer vision workflows:
- Few images, many labels
Instead of many small images containing just a few labels, remote sensing projects usually have just a few images covering hundreds of square kilometers. Each of these images can contain thousands of labels. In this case, adequately sampling the data to train a model requires a well-thought-out strategy.
- Geo data formats
Most computer vision algorithms are designed for typical image formats such as JPEG and PNG. Labels used in those algorithms are typically pixel-referenced (e.g. the coordinate origin is located at one corner of the image). Remote sensing data, on the other hand, are usually georeferenced to a real-world coordinate system and stored in spatial raster formats like GeoTIFF. There are many mature free and open source Python tools for working with geo datasets, but wiring them in can quickly complicate a computer vision system that was not built to handle remote sensing imagery natively.
How TorchGeo helps
TorchGeo provides PyTorch native infrastructure for working with remote sensing data.
GeoDatasets
TorchGeo ships with something called a GeoDataset. These GeoDatasets can wrap a folder full of georeferenced raster files or vector files containing your labels. When the GeoDataset is instantiated, it looks at all files in that folder (or set of folders) and, for each file, records its bounding box in an R-tree index. This allows the GeoDataset to act as an efficient and organized data catalog, which is built to be easily sampled for training a PyTorch model.
GeoDatasets come pre-baked in a number of different flavors. There are RasterDatasets which are used for spatial raster datasets, as well as VectorDatasets which are used for handling geospatial labels.
In practice, here’s a basic example of how to set up a TorchGeo RasterDataset for handling spatial raster files:
from torchgeo.datasets import RasterDataset
# set up the raster dataset class
class GeoTiffDataset(RasterDataset):
    filename_glob = "*.tif"

# initialize the raster dataset
raster_data = GeoTiffDataset(
    paths=[
        "/path/to/folder/with/geotif/files",
        "/path/to/another/folder/with/geotif/files",
    ]
)
This GeoDataset will natively load rasters as PyTorch tensors that include every channel in the data (allowing training on more than just 3-channel RGB data).
The dataset is initialized with a spatial index containing a bounding box for each raster file. When it comes time to create training samples, the index is used to locate samples within the dataset. This means that sampling of the entire spatial dataset can be done without needing to do any mosaicking or merging of raster files.
Projections are handled natively by TorchGeo, so that each raster file can be kept in its native projection (although at training time projecting rasters on the fly could add processing overhead).
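Under the hood, this index lookup is also something you can use directly: indexing the dataset with a bounding box in real-world coordinates returns a dictionary containing the pixels for that window. Here’s a rough sketch; the 100-unit window is just a placeholder (a sampler would normally generate these queries), and the exact metadata keys in the returned dictionary can vary between TorchGeo versions:

from torchgeo.datasets import BoundingBox

# full extent of the indexed rasters
extent = raster_data.bounds

# a small, arbitrary query window anchored at the lower-left corner of the extent
query = BoundingBox(
    minx=extent.minx, maxx=extent.minx + 100.0,
    miny=extent.miny, maxy=extent.miny + 100.0,
    mint=extent.mint, maxt=extent.maxt,
)

sample = raster_data[query]
print(sample["image"].shape)  # tensor of shape (channels, height, width)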
Similarly, georeferenced labels can be included using TorchGeo’s VectorDatasets. Here’s an example using a set of GeoPackage files containing polygons to be used as labels:
from torchgeo.datasets import VectorDataset
# set up the label dataset class
class LabelDataset(VectorDataset):
    filename_glob = "*labels.gpkg"

# initialize the label dataset
label_data = LabelDataset(
    paths=[
        "/path/to/labels",
    ]
)
The labels are organized similarly to the raster files, with an R-tree index containing the bounding boxes of each file in the dataset.
Combining Label datasets and raster datasets
TorchGeo datasets are flexible enough to allow you to very easily combine multiple datasets using either an intersection or union strategy. In the case of combining labels with rasters you can use an intersection of both datasets which will give you a single dataset that can be easily sampled in later steps.
Creating an intersection dataset is as easy as using the & operator to join two datasets together like this:
training_data = raster_data & label_data
The resulting dataset, training_data, will automatically treat the VectorDataset component of the intersection as labels and the RasterDataset component as the image source. This means you can sample directly from training_data without needing to worry about manually creating your target for each sample.
This automatic creation of the target is a significant benefit of using TorchGeo datasets. Targets like segmentation masks are automatically created for you from the polygons in the label dataset, in the format expected by most of the TorchVision models you might use. This means there is no need to spend time learning the specific mask formats expected by TorchVision models, or to get caught debugging bespoke polygon-to-mask algorithms while sampling. In our experience this not only speeds up the integration of new data into the training set, it also removes a significant amount of complexity from the sampling code.
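For example, a single sample pulled from training_data is a dictionary carrying both the imagery and the rasterized mask (the keys shown here are typical, but worth confirming against your TorchGeo version):

# query is a BoundingBox like the one shown earlier
sample = training_data[query]
image = sample["image"]  # float tensor of shape (channels, height, width)
mask = sample["mask"]    # segmentation mask rasterized from the label polygons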
Geo-based samplers
In this context, sampling refers to extracting image chips from the dataset. For example, to train the model on 256 x 256 pixel image chips, we will need to extract many 256 x 256 pixel image chips from the dataset.
Deciding on an adequate strategy for sampling of remote sensing imagery can be complicated. Raster datasets can cover very large areas, and as highlighted earlier, labels might be comparatively small and sparsely distributed over the raster area. A good sampling strategy has to account for things like:
- The balancing of samples per label type.
- Inclusion of negative samples and related balancing.
TorchGeo provides geo-samplers out of the box, which makes things easy, although for more complicated scenarios you may need to extend these samplers to implement a custom strategy.
Here’s an example taking random samples from the training_data dataset using TorchGeo’s RandomGeoSampler:
from torchgeo.samplers import RandomGeoSampler
sampler = RandomGeoSampler(training_data, size=256, length=10000)
This sampler will allow us to extract 10,000 256 x 256 pixel image chips from random locations within the area of intersection between the label dataset and the raster dataset.
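If you want to drive training yourself rather than use Lightning (covered in the next section), the sampler plugs into a standard PyTorch DataLoader together with TorchGeo’s stack_samples collate function. A minimal sketch:

from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples

# stack_samples collates the per-chip sample dictionaries into batched tensors
dataloader = DataLoader(
    training_data,
    sampler=sampler,
    batch_size=16,
    collate_fn=stack_samples,
    num_workers=4,
)

for batch in dataloader:
    images = batch["image"]  # batched image chips
    masks = batch["mask"]    # batched segmentation targets
    # ... forward pass, loss, backprop
    break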
Putting It All Together in a PyTorch Lightning Data Module
Each of the TorchGeo objects built above can be combined within a data module and very easily used to train a TorchVision model using PyTorch Lightning.
from torchgeo.datasets import RasterDataset, VectorDataset
from torchgeo.datamodules import GeoDataModule
IMG_SIZE = 256
BATCH_SIZE = 16
SAMPLE_SIZE = 10000
WORKERS = 4
# set up the raster and label dataset classes
class GeoTiffDataset(RasterDataset):
    filename_glob = "*.tif"

class LabelDataset(VectorDataset):
    filename_glob = "*labels.gpkg"

# initialize the raster and label datasets
raster_data = GeoTiffDataset(
    paths=[
        "/path/to/geotiff/files",
        "/path/to/more/geotiff/files",
    ]
)
label_data = LabelDataset(
    paths=[
        "/path/to/labels/geopackage/files",
    ]
)

# Create an intersection dataset from rasters and labels
training_data = raster_data & label_data

# initialize the lightning data module
datamodule = GeoDataModule(
    dataset_class=type(training_data),  # the intersection dataset class
    batch_size=BATCH_SIZE,
    patch_size=IMG_SIZE,
    length=SAMPLE_SIZE,
    num_workers=WORKERS,
    dataset1=raster_data,  # passed to the intersection dataset init
    dataset2=label_data,   # passed to the intersection dataset init
)
The above code snippet is a highly simplified example which would almost certainly need to be customized for any real use case. For example, it does not explicitly show how to implement transformations, but you can implement transforms in the same way suggested in the PyTorch Lightning docs.
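One lightweight alternative worth mentioning (our suggestion, not something shown in the snippet above) is that TorchGeo datasets also accept a transforms callable which receives and returns the sample dictionary, making it a convenient place for simple per-sample preprocessing:

def scale_to_unit(sample):
    # simple per-sample preprocessing: scale 8-bit imagery to [0, 1]
    sample["image"] = sample["image"].float() / 255.0
    return sample

# the transforms callable is applied to every sample the dataset returns
raster_data = GeoTiffDataset(
    paths=["/path/to/geotiff/files"],
    transforms=scale_to_unit,
)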
Note that the data module itself abstracts away the geo-sampler being used. This is fine for this simple example, but for something more complex you can customize the sampling strategy by overriding the data module’s setup() method.
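Here’s a rough sketch of what that override could look like. The attribute names (train_dataset, train_sampler, and so on) follow the GeoDataModule conventions, but check them against the TorchGeo version you are using; the samplers chosen here are purely illustrative:

from torchgeo.datamodules import GeoDataModule
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

class SolarDataModule(GeoDataModule):
    """Illustrative subclass that swaps in its own sampling strategy."""

    def setup(self, stage):
        # build the intersection dataset from the kwargs passed at init
        dataset = self.dataset_class(**self.kwargs)
        if stage == "fit":
            self.train_dataset = dataset
            # replace this with any custom sampler (e.g. a balanced sampler)
            self.train_sampler = RandomGeoSampler(
                dataset, size=self.patch_size, length=self.length
            )
        if stage in ("fit", "validate"):
            self.val_dataset = dataset
            self.val_sampler = GridGeoSampler(
                dataset, size=self.patch_size, stride=self.patch_size
            )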
Using this generalized TorchGeo framework gives a clean, organized, and consistent way to train almost any kind of PyTorch model on remote sensing data.
The Solar Panel dataset
We at Sparkgeo used the above framework to train a computer vision model to find and segment solar panels in remote sensing imagery. In our case the dataset we used is publicly available 30 cm resolution RGB imagery collected over four cities in California. The labels in the dataset are georeferenced polygons outlining all solar panels within the spatial bounds of the imagery.
One of the main hurdles in using these data was that the labels are sparsely distributed (and sometimes clustered) over the raster area. To handle this, we built a custom sampler that ensured our image chips adequately balanced positive and negative samples.
Handling this balancing just right proved to be an important step in training a model that was able to adequately differentiate between things that looked like solar panels (parking lots with particular patterns painted on them for example), and actual solar panels.
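We won’t reproduce our production sampler here, but to give a flavor of the idea, here is a purely illustrative rejection-sampling sketch: draw random chips, rasterize the labels for each candidate, and keep a target mix of chips with and without labels. The class name and positive_fraction parameter are hypothetical, and the sketch assumes you are sampling over the intersection dataset so every candidate chip overlaps at least one label file:

import random

from torchgeo.samplers import RandomGeoSampler

class BalancedGeoSampler(RandomGeoSampler):
    """Illustrative only: yield roughly positive_fraction chips that contain labels."""

    def __init__(self, dataset, label_data, size, length, positive_fraction=0.5):
        super().__init__(dataset, size=size, length=length)
        self.label_data = label_data
        self.positive_fraction = positive_fraction

    def __iter__(self):
        yielded = 0
        while yielded < self.length:
            for bbox in super().__iter__():
                # rasterize the labels for this candidate chip; simple but slow
                # (production code would query the spatial index instead)
                has_labels = bool(self.label_data[bbox]["mask"].any())
                wants_positive = random.random() < self.positive_fraction
                if has_labels == wants_positive:
                    yield bbox
                    yielded += 1
                    if yielded >= self.length:
                        break

sampler = BalancedGeoSampler(training_data, label_data, size=256, length=10000)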
The image above shows an example of a sampled image chip, with the segmentation masks (color shaded areas on top of solar panels) automatically generated from the label polygons.
Conclusion
TorchGeo simplifies a number of the complications that come up when training a PyTorch model on remote sensing data. We had a positive experience using it in our solar panel computer vision project, and will likely use it more in the future.