RosettaCodeData/Task/Mandelbrot-set/Python/mandelbrot-set-9.py

86 lines
3.0 KiB
Python

import jax
jax.config.update("jax_enable_x64", True) # faster on GPU P100 than on GPU T4
import numpy as np
import matplotlib.pyplot as plt
import cupy as cp
import jax.numpy as jnp
import decimal as dc # decimal floating point arithmetic with arbitrary precision
dc.getcontext().prec = 80 # set precision to 80 digits (about 256 bits)

# Rendering parameters.
d, h = 100, 2000 # pixel density (= image width) and image height
n, r = 100000, 100000.0 # number of iterations and escape radius (r > 2)

# Reference point a + b*i, given as decimal strings so that no digits are lost.
a = dc.Decimal("-1.256827152259138864846434197797294538253477389787308085590211144291")
b = dc.Decimal(".37933802890364143684096784819544060002129071484943239316486643285025")

# Reference orbit: iterate z -> z*z + (a + b*i) from z = 0 in Decimal arithmetic
# and store every term, rounded to double precision, in S[0..n].
S = np.zeros(n+1, dtype=np.complex128)
u, v = dc.Decimal(0), dc.Decimal(0)
escape_sq = dc.Decimal(r * r) # exact conversion, done once instead of on every comparison
for i in range(n+1):
    S[i] = float(u) + float(v) * 1j
    uu, vv = u * u, v * v # both squares are needed for the test and for the update
    if uu + vv < escape_sq:
        u, v = uu - vv + a, 2 * u * v + b
    else:
        # Entries S[i+1:] keep their initial value of zero once the loop stops here.
        print("The reference sequence diverges within %s iterations." % i)
        break

# Pixel grid: x spans [0, 2] across the width, y spans [0, 2*h/d] along the height.
x = np.linspace(0, 2, num=d+1, dtype=np.float64)
y = np.linspace(0, 2 * h / d, num=h+1, dtype=np.float64)

# C = -8 * exp(i*A - B): A sets the argument, B shrinks the modulus as 8 * exp(-B).
A, B = np.meshgrid(x * np.pi, y * np.pi)
C = (- 8.0) * np.exp((A + B * 1j) * 1j)
def iteration_cupy(S, C):
    """Run the perturbation iteration with CuPy and return the final state.

    S is the reference orbit and C the array of per-pixel offsets; both are
    moved to the device with cp.asarray. The loop count and escape radius are
    read from the module-level names n and r.

    Returns the tuple (I, E, Z, dZ), each fetched back with .get():
    I indexes into the reference orbit, E is the offset from the reference
    term, Z is the iterate itself and dZ its running derivative.
    """

    def run_on_device(orbit, offsets):
        ref_idx = cp.zeros(offsets.shape, dtype=np.intp)
        eps = cp.zeros_like(offsets)
        z = cp.zeros_like(offsets)
        dz = cp.zeros_like(offsets)

        for _ in range(n):
            # rebase when z is closer to zero: restart the reference orbit
            # at index 0 and let the offset take the full value of z
            rebase = cp.absolute(z) < cp.absolute(eps)
            ref_idx = cp.where(rebase, 0, ref_idx)
            eps = cp.where(rebase, z, eps)

            # only pixels still inside the escape radius advance
            active = cp.absolute(z) < r
            eps_next = cp.where(active, (2 * orbit[ref_idx] + eps) * eps + offsets, eps)
            idx_next = cp.where(active, ref_idx + 1, ref_idx)
            # the derivative uses the previous z, the new z uses the advanced index
            dz_next = cp.where(active, 2 * z * dz + 1, dz)
            z_next = cp.where(active, orbit[idx_next] + eps_next, z)

            ref_idx, eps, z, dz = idx_next, eps_next, z_next, dz_next

        return ref_idx, eps, z, dz

    device_state = run_on_device(cp.asarray(S), cp.asarray(C))
    return tuple(part.get() for part in device_state)
def iteration_jax(S, C):
    """Run the perturbation iteration with JAX and return NumPy arrays.

    S is the reference orbit and C the array of per-pixel offsets; both are
    converted with jnp.asarray. The loop count and escape radius are read from
    the module-level names n and r.

    Returns the tuple (I, E, Z, dZ) converted with np.asarray:
    I indexes into the reference orbit, E is the offset from the reference
    term, Z is the iterate itself and dZ its running derivative.
    """

    def norm_sq(w):
        # squared modulus, so the comparisons below need no square root
        return w.real * w.real + w.imag * w.imag

    def run(orbit, offsets):
        escape_sq = norm_sq(r)

        def step(_, carry):
            ref_idx, eps, z, dz = carry
            z_sq = norm_sq(z)

            # rebase when z is closer to zero: restart the reference orbit
            # at index 0 and let the offset take the full value of z
            rebase = z_sq < norm_sq(eps)
            ref_idx = jnp.where(rebase, 0, ref_idx)
            eps = jnp.where(rebase, z, eps)

            # only pixels still inside the escape radius advance
            active = z_sq < escape_sq
            eps_next = jnp.where(active, (2 * orbit[ref_idx] + eps) * eps + offsets, eps)
            idx_next = jnp.where(active, ref_idx + 1, ref_idx)
            # the new z uses the advanced index, the derivative the previous z
            z_next = jnp.where(active, orbit[idx_next] + eps_next, z)
            dz_next = jnp.where(active, 2 * z * dz + 1, dz)
            return idx_next, eps_next, z_next, dz_next

        start = (
            jnp.zeros(offsets.shape, dtype=np.intp),
            jnp.zeros_like(offsets),
            jnp.zeros_like(offsets),
            jnp.zeros_like(offsets),
        )
        return jax.lax.fori_loop(0, n, step, start, unroll=10)

    ref_idx, eps, z, dz = run(jnp.asarray(S), jnp.asarray(C))
    return np.asarray(ref_idx), np.asarray(eps), np.asarray(z), np.asarray(dz)
# Iterate all pixels on the accelerator; the results come back as host arrays.
I, E, Z, dZ = iteration_jax(S, C) # use iteration_cupy or iteration_jax

# Exterior distance estimation: log|Z| * |Z| / |dZ| where |Z| > 2, zero elsewhere.
N = np.abs(Z) > 2
modulus = np.abs(Z[N])
D = np.zeros(C.shape, dtype=np.float64)
D[N] = np.log(modulus) * modulus / np.abs(dZ[N])

# The small exponent compresses the range of D before colour mapping;
# the transpose puts the first array axis along the horizontal image axis.
plt.imshow(np.power(D.T, 0.015), cmap=plt.cm.gist_ncar, origin="lower")
plt.savefig("Mercator_Mandelbrot_deep_map.png", dpi=200)