Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Interpolate transform #2

Merged
merged 18 commits into from
Sep 19, 2023
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12"
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Expand All @@ -25,6 +26,7 @@ CategoricalArrays = "0.10"
Clustering = "0.15"
Combinatorics = "1.0"
Distances = "0.10"
GeoStatsModels = "0.1"
GeoTables = "1.6"
Meshes = "0.35"
ScientificTypes = "3.0"
Expand Down
5 changes: 5 additions & 0 deletions src/GeoStatsTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module GeoStatsTransforms

using Meshes
using GeoTables
using GeoStatsModels

using Tables
using TableDistances
Expand All @@ -20,6 +21,7 @@ using SparseArrays
using LinearAlgebra
using Statistics

import GeoStatsModels: GeoStatsModel, fit, predict, predictprob
eliascarv marked this conversation as resolved.
Show resolved Hide resolved
import TableTransforms: ColSpec, Col, AllSpec, NoneSpec
import TableTransforms: colspec, choose
import TableTransforms: divide, attach
Expand All @@ -39,8 +41,11 @@ include("rasterize.jl")
include("potrace.jl")
include("detrend.jl")

include("interpolate.jl")

eliascarv marked this conversation as resolved.
Show resolved Hide resolved
export
# transforms
Interpolate,
UniqueCoords,
Rasterize,
Potrace,
Expand Down
168 changes: 168 additions & 0 deletions src/interpolate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
Interpolate(
domain,
vars₁ => model₁, ..., varsₙ => modelₙ;
minneighbors=1,
maxneighbors=10,
neighborhood=nothing,
distance=Euclidean(),
path=LinearPath()
point=true,
prob=false
)

TODO
eliascarv marked this conversation as resolved.
Show resolved Hide resolved

## Global Parameters
eliascarv marked this conversation as resolved.
Show resolved Hide resolved

* `minneighbors` - Minimum number of neighbors (default to `1`)
* `maxneighbors` - Maximum number of neighbors (default to `10`)
* `neighborhood` - Search neighborhood (default to `nothing`)
* `distance` - A distance defined in Distances.jl (default to `Euclidean()`)
* `exponent` - Exponent of the distances (default to `1`)
eliascarv marked this conversation as resolved.
Show resolved Hide resolved
* `path` - The path algorithm used to iterate over the domain (default to `LinearPath()`)
eliascarv marked this conversation as resolved.
Show resolved Hide resolved

The `maxneighbors` option can be used to perform inverse distance weighting
eliascarv marked this conversation as resolved.
Show resolved Hide resolved
with a subset of measurements per prediction location. If `maxneighbors`
is not provided, then all measurements are used.

Two `neighborhood` search methods are available:

* If a `neighborhood` is provided, local prediction is performed
by sliding the `neighborhood` in the domain.

* If a `neighborhood` is not provided, the prediction is performed
using `maxneighbors` nearest neighbors according to `distance`.
"""
struct Interpolate{D<:Domain,N,M,P} <: TableTransform
domain::D
colspecs::Vector{ColSpec}
models::Vector{GeoStatsModel}
minneighbors::Int
maxneighbors::Int
neighborhood::N
distance::M
path::P
point::Bool
prob::Bool
end

Interpolate(
domain::Domain,
colspecs,
models;
minneighbors=1,
maxneighbors=10,
neighborhood=nothing,
distance=Euclidean(),
path=LinearPath(),
point=true,
prob=false
) = Interpolate(
domain,
collect(ColSpec, colspecs),
collect(GeoStatsModel, models),
minneighbors,
maxneighbors,
neighborhood,
distance,
path,
point,
prob
)

Interpolate(domain::Domain; distance=Euclidean(), kwargs...) =
Interpolate(domain, [AllSpec()], [IDW(1, distance)]; distance, kwargs...)

Interpolate(domain::Domain, pairs::Pair{<:Any,<:GeoStatsModel}...; kwargs...) =
Interpolate(domain, colspec.(first.(pairs)), last.(pairs); kwargs...)

isrevertible(::Type{<:Interpolate}) = false

function apply(transform::Interpolate, geotable::AbstractGeoTable)
dom = domain(geotable)
tab = values(geotable)
cols = Tables.columns(tab)
vars = Tables.columnnames(cols)

idom = transform.domain
colspecs = transform.colspecs
models = transform.models
minneighbors = transform.minneighbors
maxneighbors = transform.maxneighbors
neighborhood = transform.neighborhood
distance = transform.distance
path = transform.path
point = transform.point
prob = transform.prob

nobs = nrow(geotable)
if maxneighbors > nobs || maxneighbors < 1
@warn "Invalid maximum number of neighbors. Adjusting to $nobs..."
maxneighbors = nobs
end

if minneighbors > maxneighbors || minneighbors < 1
@warn "Invalid minimum number of neighbors. Adjusting to 1..."
minneighbors = 1
end

vdom = point ? PointSet(centroid(dom, i) for i in 1:nobs) : dom
eliascarv marked this conversation as resolved.
Show resolved Hide resolved
searcher = searcher_ui(vdom, maxneighbors, distance, neighborhood)

# preprocess variable models
varmodels = mapreduce(vcat, colspecs, models) do colspec, model
svars = choose(colspec, vars)
svars .=> Ref(model)
end

# pre-allocate memory for neighbors
neighbors = Vector{Int}(undef, maxneighbors)

# prediction order
inds = traverse(idom, path)

# predict variable values
function pred(var, model)
map(inds) do ind
# centroid of estimation
center = centroid(idom, ind)

# find neighbors with data
nneigh = search!(neighbors, center, searcher)

# skip if there are too few neighbors
if nneigh < minneighbors
missing
else
# final set of neighbors
ninds = view(neighbors, 1:nneigh)

# view neighborhood with data
samples = view(geotable, ninds)
eliascarv marked this conversation as resolved.
Show resolved Hide resolved

# fit model to data
fmodel = fit(model, samples)

# save prediction
geom = point ? center : dom[ind]
if prob
predictprpb(fmodel, var, geom)
eliascarv marked this conversation as resolved.
Show resolved Hide resolved
else
predict(fmodel, var, geom)
end
end
end
end

pairs = (var => pred(var, model) for (var, model) in varmodels)
newtab = (; pairs...) |> Tables.materializer(tab)

newgeotable = georef(newtab, idom)

newgeotable, nothing
end
16 changes: 16 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,19 @@ function _first(x)
vs = skipmissing(x)
isempty(vs) ? missing : first(vs)
end

"""
searcher_ui(domain, maxneighbors, distance, neighborhood)

Return the appropriate search method over the `domain` based on
end-user inputs such as `maxneighbors`, `distance` and `neighborhood`.
"""
function searcher_ui(domain, maxneighbors, distance, neighborhood)
if isnothing(neighborhood)
# nearest neighbor search with a metric
KNearestSearch(domain, maxneighbors; metric=distance)
else
# neighbor search with ball neighborhood
KBallSearch(domain, maxneighbors, neighborhood)
end
end
Loading
Loading