# Python script (86 lines, 3.0 KiB)
import jax
|
|
jax.config.update("jax_enable_x64", True) # faster on GPU P100 than on GPU T4
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import cupy as cp
|
|
import jax.numpy as jnp
|
|
|
|
import decimal as dc # decimal floating point arithmetic with arbitrary precision
|
|
# Work with 80 significant decimal digits (about 266 bits, since
# 80 * log2(10) ~= 265.8) when computing the high-precision reference orbit.
dc.getcontext().prec = 80

d, h = 100, 2000  # pixel density (= image width) and image height
n, r = 100000, 100000.0  # number of iterations and escape radius (r > 2)

# Reference point c0 = a + b*i, given as exact decimal strings (the Decimal
# constructor does not round string input to the context precision).
a = dc.Decimal("-1.256827152259138864846434197797294538253477389787308085590211144291")
b = dc.Decimal(".37933802890364143684096784819544060002129071484943239316486643285025")
|
|
|
|
# Reference orbit z_{k+1} = z_k**2 + c0 (c0 = a + b*i), iterated in Decimal
# arithmetic and stored rounded to complex128. S has n + 1 slots; if the
# orbit reaches the escape radius the loop stops and later slots stay 0.
S = np.zeros(n + 1, dtype=np.complex128)
u, v = dc.Decimal(0), dc.Decimal(0)
# Decimal(float) is an exact conversion, so comparing against r2 is exactly
# the same test as comparing against the float r * r; hoisting it avoids an
# implicit float->Decimal conversion on every iteration.
r2 = dc.Decimal(r * r)

for i in range(n + 1):
    S[i] = float(u) + float(v) * 1j
    # Each 80-digit square is needed twice (bailout test and update).
    uu, vv = u * u, v * v
    if uu + vv < r2:
        u, v = uu - vv + a, 2 * u * v + b
    else:
        print(f"The reference sequence diverges within {i} iterations.")
        break
|
|
|
|
# Sample grid: A spans [0, 2*pi] over d + 1 columns and B spans
# [0, 2*pi*h/d] over h + 1 rows (meshgrid's default "xy" indexing puts x
# along columns and y along rows).
x = np.linspace(0, 2, num=d + 1, dtype=np.float64)
y = np.linspace(0, 2 * h / d, num=h + 1, dtype=np.float64)
A, B = np.meshgrid(np.pi * x, np.pi * y)

# Offsets from the reference point: C = -8 * exp(i*A - B), i.e. argument set
# by A and magnitude 8 * exp(-B), so |C| shrinks exponentially row by row.
C = -8.0 * np.exp(1j * (A + 1j * B))
|
|
|
|
def iteration_cupy(S, C, num_iter=None, escape_radius=None):
    """Run the rebased perturbation iteration for all offsets ``C`` with CuPy.

    Each sample point c = c0 + C is tracked as z = S[I] + E, where ``S`` is
    the reference orbit of c0, ``I`` the current index into it and ``E`` the
    perturbation, updated by E <- (2 S[I] + E) E + C. Whenever |z| < |E| the
    point is rebased onto the start of the reference orbit (I = 0, E = z).
    The derivative dZ = dz/dc is carried along as dZ <- 2 z dZ + 1. Points
    whose |z| has reached the escape radius are no longer updated.

    Parameters
    ----------
    S : array_like, complex
        Reference orbit. ``I`` can advance once per iteration, so ``S``
        should have at least ``num_iter + 1`` entries.
    C : array_like, complex
        Offsets of the sample points from the reference point.
    num_iter : int, optional
        Number of iterations; defaults to the module-level ``n``.
    escape_radius : float, optional
        Bailout radius; defaults to the module-level ``r``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(I, E, Z, dZ)`` copied back to host memory, each shaped like ``C``.
    """
    if num_iter is None:
        num_iter = n
    if escape_radius is None:
        escape_radius = r

    S = cp.asarray(S)
    C = cp.asarray(C)
    I = cp.zeros(C.shape, dtype=np.intp)
    E, Z, dZ = cp.zeros_like(C), cp.zeros_like(C), cp.zeros_like(C)

    for _ in range(num_iter):
        absZ = cp.absolute(Z)  # rebasing leaves Z untouched, so |Z| is reused
        M = absZ < cp.absolute(E)  # rebase when z is closer to zero
        I, E = cp.where(M, 0, I), cp.where(M, Z, E)  # reset reference orbit
        M = absZ < escape_radius  # advance only points that have not escaped
        # Each tuple assignment evaluates its right-hand side first: the E
        # update reads the old S[I], and dZ reads the old Z, while the new
        # Z = S[I] + E uses the freshly advanced I and E.
        I, E = cp.where(M, I + 1, I), cp.where(M, (2 * S[I] + E) * E + C, E)
        Z, dZ = cp.where(M, S[I] + E, Z), cp.where(M, 2 * Z * dZ + 1, dZ)

    return I.get(), E.get(), Z.get(), dZ.get()
|
|
|
|
def iteration_jax(S, C, num_iter=None, escape_radius=None):
    """Run the rebased perturbation iteration for all offsets ``C`` with JAX.

    Each sample point c = c0 + C is tracked as z = S[I] + E, where ``S`` is
    the reference orbit of c0, ``I`` the current index into it and ``E`` the
    perturbation, updated by E <- (2 S[I] + E) E + C. Whenever |z| < |E| the
    point is rebased onto the start of the reference orbit (I = 0, E = z).
    The derivative dZ = dz/dc is carried along as dZ <- 2 z dZ + 1. Points
    whose |z| has reached the escape radius are no longer updated.

    Parameters
    ----------
    S : array_like, complex
        Reference orbit. ``I`` can advance once per iteration, so ``S``
        should have at least ``num_iter + 1`` entries (JAX clamps
        out-of-range gather indices instead of raising).
    C : array_like, complex
        Offsets of the sample points from the reference point.
    num_iter : int, optional
        Number of iterations; defaults to the module-level ``n``.
    escape_radius : float, optional
        Bailout radius; defaults to the module-level ``r``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(I, E, Z, dZ)``, each shaped like ``C``.
    """
    if num_iter is None:
        num_iter = n
    if escape_radius is None:
        escape_radius = r

    S = jnp.asarray(S)
    C = jnp.asarray(C)

    def abs2(z):
        # Squared magnitude: avoids the square root taken by jnp.abs.
        return z.real * z.real + z.imag * z.imag

    radius2 = abs2(escape_radius)

    def iterate(_, V):
        I, E, Z, dZ = V
        Z2 = abs2(Z)  # rebasing leaves Z untouched, so |Z|^2 is reused
        M = Z2 < abs2(E)  # rebase when z is closer to zero
        I, E = jnp.where(M, 0, I), jnp.where(M, Z, E)  # reset reference orbit
        M = Z2 < radius2  # advance only points that have not escaped
        # Each tuple assignment evaluates its right-hand side first: the E
        # update reads the old S[I], and dZ reads the old Z, while the new
        # Z = S[I] + E uses the freshly advanced I and E.
        I, E = jnp.where(M, I + 1, I), jnp.where(M, (2 * S[I] + E) * E + C, E)
        Z, dZ = jnp.where(M, S[I] + E, Z), jnp.where(M, 2 * Z * dZ + 1, dZ)
        return I, E, Z, dZ

    init = (
        jnp.zeros(C.shape, dtype=np.intp),
        jnp.zeros_like(C),
        jnp.zeros_like(C),
        jnp.zeros_like(C),
    )
    # Unroll by 10 as before, but never by more than the trip count.
    I, E, Z, dZ = jax.lax.fori_loop(
        0, num_iter, iterate, init, unroll=max(1, min(10, num_iter))
    )
    return np.asarray(I), np.asarray(E), np.asarray(Z), np.asarray(dZ)
|
|
|
|
# Run the perturbation iteration (iteration_cupy is a drop-in alternative).
I, E, Z, dZ = iteration_jax(S, C)

# Exterior distance estimation: |z| * log|z| / |dz| for points with |z| > 2;
# every other pixel keeps the value 0.
absZ = np.abs(Z)
N = absZ > 2
D = np.zeros(C.shape, dtype=np.float64)
D[N] = np.log(absZ[N]) * absZ[N] / np.abs(dZ[N])

# The small exponent compresses the estimate's range before coloring.
plt.imshow(D.T ** 0.015, cmap=plt.cm.gist_ncar, origin="lower")
plt.savefig("Mercator_Mandelbrot_deep_map.png", dpi=200)
|