Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions ext/BatsrusPyPlotExt/pyplot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,14 @@ function plotgrid(
ax = plt.gca()
end

if bd isa BatsrusIDLStructured || ndims(bd.x) == 3
X, Y = eachslice(bd.x, dims = 3)
# Use pcolormesh with transparent faces to show the grid
X, Y = eachslice(bd.x, dims = ndims(bd.x))

if X isa AbstractMatrix && size(X, 1) > 1 && size(X, 2) > 1
c = ax.pcolormesh(
X, Y, zeros(size(X)...), edgecolors = "k",
facecolors = "none", linewidths = 0.5, kwargs...
)
else
X, Y = eachslice(bd.x, dims = 2)
c = ax.scatter(X, Y, marker = ".", alpha = 0.6, kwargs...)
end

Expand Down Expand Up @@ -531,9 +530,9 @@ function PyPlot.tricontourf(
x, w = bd.x, bd.w
varIndex_ = findindex(bd, var)

X = vec(x[:, :, 1])
Y = vec(x[:, :, 2])
W = vec(w[:, :, varIndex_])
X = vec(selectdim(x, ndims(x), 1))
Y = vec(selectdim(x, ndims(x), 2))
W = vec(selectdim(w, ndims(w), varIndex_))

#TODO This needs improvement.
if !all(isinf.(plotrange))
Expand Down Expand Up @@ -570,9 +569,9 @@ function PyPlot.tricontour(
x, w = bd.x, bd.w
varIndex_ = findindex(bd, var)

X = vec(x[:, :, 1])
Y = vec(x[:, :, 2])
W = vec(w[:, :, varIndex_])
X = vec(selectdim(x, ndims(x), 1))
Y = vec(selectdim(x, ndims(x), 2))
W = vec(selectdim(w, ndims(w), varIndex_))

#TODO This needs improvement.
if !all(isinf.(plotrange))
Expand Down Expand Up @@ -600,8 +599,8 @@ function PyPlot.triplot(
plotrange = [-Inf, Inf, -Inf, Inf],
kwargs...
) where {TV}
X = vec(bd.x[:, :, 1])
Y = vec(bd.x[:, :, 2])
X = vec(selectdim(bd.x, ndims(bd.x), 1))
Y = vec(selectdim(bd.x, ndims(bd.x), 2))
triang = PyPlot.matplotlib.tri.Triangulation(X, Y)
#TODO This needs improvement.
if !all(isinf.(plotrange))
Expand Down Expand Up @@ -630,9 +629,9 @@ function PyPlot.plot_trisurf(
x, w = bd.x, bd.w
varIndex_ = findindex(bd, var)

X = vec(x[:, :, 1])
Y = vec(x[:, :, 2])
W = vec(w[:, :, varIndex_])
X = vec(selectdim(x, ndims(x), 1))
Y = vec(selectdim(x, ndims(x), 2))
W = vec(selectdim(w, ndims(w), varIndex_))

#TODO This needs improvement.
if !all(isinf.(plotrange))
Expand Down Expand Up @@ -736,11 +735,10 @@ function PyPlot.tripcolor(

varIndex_ = findindex(bd, var)

X, Y = eachslice(x, dims = 3)
X, Y = eachslice(x, dims = ndims(x))
Comment thread
henry2004y marked this conversation as resolved.
adjust_plotrange!(plotrange, extrema(X), extrema(Y))
W = vec(w[:, :, varIndex_])
W = vec(selectdim(w, ndims(w), varIndex_))

adjust_plotrange!(plotrange, extrema(X), extrema(Y))
triang = PyPlot.matplotlib.tri.Triangulation(vec(X), vec(Y))

# Mask off unwanted triangles at the inner boundary.
Expand Down Expand Up @@ -850,26 +848,42 @@ function _getvector(
) where {TV}
x, w = bd.x, bd.w
varstream = split(var, ";")
var1_ = findfirst(x -> lowercase(x) == lowercase(varstream[1]), bd.head.wname)
var2_ = findfirst(x -> lowercase(x) == lowercase(varstream[2]), bd.head.wname)
plot_step = isinf(plotinterval) ? (x[end, 1, 1] - x[1, 1, 1]) / size(x, 1) :
plotinterval
var1_ = findindex(bd, varstream[1])
var2_ = findindex(bd, varstream[2])

if bd.head.gencoord # generalized coordinates
X, Y = vec(x[:, :, 1]), vec(x[:, :, 2])
X, Y = vec(selectdim(x, ndims(x), 1)), vec(selectdim(x, ndims(x), 2))
adjust_plotrange!(plotrange, extrema(X), extrema(Y))

# Create grid values first.
xi = range(Float64(plotrange[1]), stop = Float64(plotrange[2]), step = plot_step)
yi = range(Float64(plotrange[3]), stop = Float64(plotrange[4]), step = plot_step)
x1, x2 = Float64(plotrange[1]), Float64(plotrange[2])
y1, y2 = Float64(plotrange[3]), Float64(plotrange[4])

if isinf(plotinterval)
n = max(round(Int, sqrt(length(X))), 100)
xi = range(x1, stop = x2, length = n)
yi = range(y1, stop = y2, length = n)
else
xi = range(x1, stop = x2, step = plotinterval)
yi = range(y1, stop = y2, step = plotinterval)
if length(xi) < 2
xi = range(x1, stop = x2, length = 2)
end
if length(yi) < 2
yi = range(y1, stop = y2, length = 2)
end
end

# Is there a triangulation method in Julia?
tr = PyPlot.matplotlib.tri.Triangulation(X, Y)
Xi, Yi = meshgrid(xi, yi)

interpolator = PyPlot.matplotlib.tri.LinearTriInterpolator(tr, w[:, 1, var1_])
w1 = vec(selectdim(w, ndims(w), var1_))
w2 = vec(selectdim(w, ndims(w), var2_))

interpolator = PyPlot.matplotlib.tri.LinearTriInterpolator(tr, w1)
v1 = interpolator(Xi, Yi)

interpolator = PyPlot.matplotlib.tri.LinearTriInterpolator(tr, w[:, 1, var2_])
interpolator = PyPlot.matplotlib.tri.LinearTriInterpolator(tr, w2)
v2 = interpolator(Xi, Yi)
else # Cartesian coordinates
xrange, yrange = get_range(bd)
Expand Down
4 changes: 2 additions & 2 deletions src/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function _interp2d_unstructured(
innermask::Bool, rbody::Union{Nothing, Real}, useMatplotlib::Bool
) where {TV, TX, TW}
x = bd.x
X, Y = eachslice(x, dims = 3)
X, Y = eachslice(x, dims = ndims(x))
X, Y = vec(X), vec(Y)
Ws = [vec(W_raw) for W_raw in Ws_raw]

Expand Down Expand Up @@ -221,7 +221,7 @@ function meshgrid(
)
x = bd.x

X, Y = eachslice(x, dims = 3)
X, Y = eachslice(x, dims = ndims(x))
X, Y = vec(X), vec(Y)

adjust_plotrange!(plotrange, extrema(X), extrema(Y))
Expand Down
14 changes: 14 additions & 0 deletions test/tests_plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ using StaticArrays
plt.close()
end

@testset "Unstructured 2D cut" begin
file = "bx0_mhd_6_t00000100_n00000352.out"
bd = load(joinpath(datapath, file))
plt.figure()
c = plotgrid(bd)
@test c isa PyPlot.PyObject
plt.close()

plt.figure()
c = streamplot(bd, "bx;bz")
@test c isa PyPlot.PyObject
plt.close()
end

# 2. Batl (AMR)
@testset "Batl (AMR)" begin
# Mock Head
Expand Down
Loading