Source code for xrspatial.pathfinding

import heapq
import warnings
from typing import Union, Optional

import numpy as np
import xarray as xr

from xrspatial.utils import ngjit
from xrspatial.utils import get_dataarray_resolution


# Sentinel index meaning "no pixel": marks parents that have not been set in
# the A* search and is what the pixel-lookup helpers return when no pixel
# qualifies.
NONE = -1


def _get_pixel_id(point, raster, xdim=None, ydim=None):
    # get location in `raster` pixel space for `point` in y-x coordinate space
    # point: (y, x) - coordinates of the point
    # xdim: name of the x coordinate dimension in input `raster`.
    # ydim: name of the y coordinate dimension in input `raster`.
    # returns: (row, column) of the pixel whose coordinate is nearest to
    #   `point` along each axis. A point lying more than half a cell before
    #   the first coordinate yields a negative index, and one more than half
    #   a cell past the last coordinate yields an index >= the axis length,
    #   so callers can detect out-of-raster points with a bounds check.

    if ydim is None:
        ydim = raster.dims[-2]
    if xdim is None:
        xdim = raster.dims[-1]
    y_coords = raster.coords[ydim].data
    x_coords = raster.coords[xdim].data

    cellsize_x, cellsize_y = get_dataarray_resolution(raster, xdim, ydim)
    py = _coord_to_index(point[0], y_coords, cellsize_y)
    px = _coord_to_index(point[1], x_coords, cellsize_x)

    # return index of row and column where the `point` located.
    return py, px


def _coord_to_index(value, coords, cellsize):
    # Convert a coordinate `value` into an index along an axis whose
    # coordinates are `coords`, spaced `cellsize` apart.
    # The sign convention of `cellsize` is not relied upon: its magnitude is
    # used and the direction of travel is read from `coords` itself, so both
    # ascending and descending axes work.
    step = abs(cellsize)
    if len(coords) > 1 and coords[-1] < coords[0]:
        step = -step
    # Round to the nearest coordinate instead of truncating. Truncation
    # turns float round-off such as 0.3 / 0.1 == 2.9999999999999996 into the
    # previous pixel. It also assigns the half of a cell that lies before
    # its coordinate to the previous pixel. The offset is kept signed (no
    # abs()) so points before the first coordinate are not mirrored back
    # into the raster.
    return int(np.floor((value - coords[0]) / step + 0.5))


@ngjit
def _is_not_crossable(cell_value, barriers):
    # A cell blocks movement when it is NaN (no data) or when its value
    # equals any entry of `barriers`.
    # returns: True for a blocked cell, False for a walkable one.
    blocked = np.isnan(cell_value)
    for barrier in barriers:
        # stop scanning as soon as the cell is known to be blocked
        if blocked:
            break
        blocked = cell_value == barrier
    return blocked


@ngjit
def _is_inside(py, px, h, w):
    # True when (py, px) is a valid (row, column) index into a grid with
    # `h` rows and `w` columns; negative indices count as outside.
    return 0 <= py < h and 0 <= px < w


@ngjit
def _distance(x1, y1, x2, y2):
    # Straight-line (Euclidean) distance in pixel space between
    # (y1, x1) and (y2, x2).
    dx = x1 - x2
    dy = y1 - y2
    return np.sqrt(dx ** 2 + dy ** 2)


@ngjit
def _heuristic(x1, y1, x2, y2):
    # A* estimate of the remaining cost from (y1, x1) to (y2, x2): the
    # straight-line pixel distance. The search charges each step the
    # Euclidean distance between adjacent pixels, so this estimate never
    # exceeds the true remaining path cost.
    # TODO: what if we want to use another distance metric?
    estimate = _distance(x1, y1, x2, y2)
    return estimate


@ngjit
def _min_cost_pixel_id(cost, is_open):
    # Find the open pixel with the lowest cost.
    # cost: 2D array of node costs.
    # is_open: 2D boolean mask of the same shape; only True pixels compete.
    # returns: (row, column) of the cheapest open pixel. Ties go to the first
    #   pixel in row-major order. Returns (NONE, NONE) when no pixel is open.
    height, width = cost.shape
    py = NONE
    px = NONE
    # Start from +inf so that every finite cost on an open pixel can win.
    # A finite guess such as (height + width) ** 2 silently skips open
    # pixels whose cost exceeds it.
    min_cost = np.inf
    for i in range(height):
        for j in range(width):
            if is_open[i, j] and cost[i, j] < min_cost:
                min_cost = cost[i, j]
                py = i
                px = j
    return py, px


@ngjit
def _find_nearest_pixel(py, px, data, barriers):
    # Find the crossable pixel of `data` closest (Euclidean, pixel space)
    # to (py, px).
    # barriers: values that, like NaN, make a cell uncrossable.
    # returns: (py, px) itself if that cell is crossable. Otherwise the
    #   nearest crossable (row, column), with ties going to the first in
    #   row-major order. Returns (NONE, NONE) if no cell is crossable.

    # if the cell is already valid, return itself
    if not _is_not_crossable(data[py, px], barriers):
        return py, px

    height, width = data.shape
    # Start from +inf so that every crossable pixel can be chosen. That
    # includes one lying exactly at the maximal in-grid distance (e.g. the
    # opposite end of a 1 x 2 grid), which a strict `<` against that maximum
    # would never accept.
    min_distance = np.inf
    # return of the function
    nearest_y = NONE
    nearest_x = NONE
    for y in range(height):
        for x in range(width):
            if _is_not_crossable(data[y, x], barriers):
                continue
            d = _distance(x, y, px, py)
            if d < min_distance:
                min_distance = d
                nearest_y = y
                nearest_x = x

    return nearest_y, nearest_x


@ngjit
def _reconstruct_path(path_img, parent_ys, parent_xs, cost,
                      start_py, start_px, goal_py, goal_px):
    # Paint the path found by the search into `path_img`, in place.
    # The path is recovered by following the parent links
    # (parent_ys[y, x], parent_xs[y, x]) back from the goal to the start.
    # Every pixel on the path, start included, receives cost[y, x], the cost
    # accumulated up to that pixel. Pixels off the path are not written
    # (presumably the caller pre-fills them with NaN -- confirm).
    # If the goal has no parent (it was never reached), nothing is written.
    if parent_xs[goal_py, goal_px] == NONE or \
            parent_ys[goal_py, goal_px] == NONE:
        return

    # the start pixel carries its own cost
    path_img[start_py, start_px] = cost[start_py, start_px]

    # walk parent links back from the goal until the start is reached
    y, x = goal_py, goal_px
    while y != start_py or x != start_px:
        path_img[y, x] = cost[y, x]
        y, x = parent_ys[y, x], parent_xs[y, x]
    return


def _neighborhood_structure(connectivity=8):
    # Build the pixel offsets that define adjacency for the search.
    # connectivity: 8 includes the diagonals. Any other value yields the
    #   4-connected (edge-sharing) neighborhood.
    # returns: (neighbor_ys, neighbor_xs) as two paired integer arrays,
    #   ordered by x offset first, then y offset.
    keep_diagonals = connectivity == 8
    offsets = [
        (dy, dx)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        if (dy, dx) != (0, 0) and (keep_diagonals or dy == 0 or dx == 0)
    ]
    ys = [dy for dy, _ in offsets]
    xs = [dx for _, dx in offsets]
    return np.array(ys), np.array(xs)


@ngjit
def _a_star_search(data, path_img, start_py, start_px, goal_py, goal_px,
                   barriers, neighbor_ys, neighbor_xs):
    # A* shortest-path search over the pixel grid `data`.
    # data: 2D array. NaN cells and cells equal to a value in `barriers`
    #   cannot be crossed.
    # path_img: 2D output array, written in place by `_reconstruct_path`
    #   when the goal is reached. Only pixels on the path are written.
    # start_py, start_px / goal_py, goal_px: pixel indices of the endpoints.
    #   They are not bounds-checked here; callers are expected to validate
    #   them (assumption -- confirm at call sites).
    # neighbor_ys, neighbor_xs: paired offsets defining adjacent pixels.
    # Each step costs the Euclidean pixel distance between the two pixels,
    # and the heuristic is the straight-line distance to the goal.
    # returns: None. The result is communicated through `path_img`.

    height, width = data.shape
    # parent of the (i, j) pixel is the pixel at
    # (parent_ys[i, j], parent_xs[i, j])
    # first initialize parent of all cells as invalid (NONE, NONE)
    parent_ys = np.ones((height, width), dtype=np.int64) * NONE
    parent_xs = np.ones((height, width), dtype=np.int64) * NONE

    # parent of start is itself
    parent_ys[start_py, start_px] = start_py
    parent_xs[start_py, start_px] = start_px

    # distance from start to the current node
    d_from_start = np.zeros_like(data, dtype=np.float64)
    # total cost of the node: cost = d_from_start + d_to_goal
    # heuristic -- estimated distance from the current node to the end node
    cost = np.zeros_like(data, dtype=np.float64)

    # initialize both open and closed list all False
    is_open = np.zeros(data.shape, dtype=np.bool_)
    is_closed = np.zeros(data.shape, dtype=np.bool_)

    # an uncrossable start leaves nothing to explore; path_img is untouched
    if _is_not_crossable(data[start_py, start_px], barriers):
        return

    # add the start node to open list with its initial cost
    is_open[start_py, start_px] = True
    d_from_start[start_py, start_px] = 0
    cost[start_py, start_px] = d_from_start[start_py, start_px] + \
        _heuristic(start_px, start_py, goal_px, goal_py)

    # The open list is a binary min-heap of (cost, y, x) tuples.
    # Popping the smallest tuple picks the lowest cost and breaks ties by
    # smallest (y, x), which is the same node a row-major scan with
    # `_min_cost_pixel_id` would pick. A pop costs O(log n) instead of
    # O(height * width). When a node's cost is lowered, a new entry is
    # pushed and the old one becomes stale. A stale entry can only reach
    # the top after its node has been expanded, so closed nodes are skipped
    # on pop.
    open_heap = [(cost[start_py, start_px], start_py, start_px)]
    while len(open_heap) > 0:
        _, py, px = heapq.heappop(open_heap)
        if is_closed[py, px]:
            # stale entry for a node that was already expanded
            continue

        # pop current node off open list, add it to closed list
        is_open[py, px] = False
        is_closed[py, px] = True

        # found the goal
        if py == goal_py and px == goal_px:
            _reconstruct_path(path_img, parent_ys, parent_xs,
                              d_from_start, start_py, start_px,
                              goal_py, goal_px)
            return

        # visit neighborhood
        for dy, dx in zip(neighbor_ys, neighbor_xs):
            ny = py + dy
            nx = px + dx

            # neighbor must lie within the surface image
            if not _is_inside(ny, nx, height, width):
                continue

            # and be walkable
            if _is_not_crossable(data[ny, nx], barriers):
                continue

            # and not already expanded
            if is_closed[ny, nx]:
                continue

            # distance from start to this neighbor through the current node
            d = d_from_start[py, px] + _distance(px, py, nx, ny)
            # an open neighbor already reached more cheaply keeps its route
            if is_open[ny, nx] and d > d_from_start[ny, nx]:
                continue

            # record the (possibly improved) route and queue the neighbor
            d_from_start[ny, nx] = d
            cost[ny, nx] = d_from_start[ny, nx] + \
                _heuristic(nx, ny, goal_px, goal_py)
            is_open[ny, nx] = True
            parent_ys[ny, nx] = py
            parent_xs[ny, nx] = px
            heapq.heappush(open_heap, (cost[ny, nx], ny, nx))
    return