107 lines
2.9 KiB
Scala
107 lines
2.9 KiB
Scala
import java.math.BigInteger
|
|
import scala.collection.immutable.List
|
|
import scala.annotation.tailrec
|
|
|
|
/**
 * Square roots modulo a prime via the Tonelli–Shanks algorithm, plus a demo
 * `main` that prints the roots for a fixed list of (n, p) cases.
 */
object TonelliShanks {

  private val ZERO = BigInteger.ZERO
  private val ONE = BigInteger.ONE
  private val TEN = BigInteger.TEN
  private val TWO = BigInteger.valueOf(2)
  private val FOUR = BigInteger.valueOf(4)

  /**
   * Outcome of solving x^2 ≡ n (mod p).
   *
   * @param root1  a root in [0, p) when `exists`; ZERO otherwise
   * @param root2  the companion root (p - root1, or 0 when root1 is 0); ZERO otherwise
   * @param exists whether n is a square modulo p
   */
  private case class Solution(root1: BigInteger, root2: BigInteger, exists: Boolean)

  /** Convenience overload of `ts` for `Long` arguments. */
  private def ts(n: Long, p: Long): Solution = ts(BigInteger.valueOf(n), BigInteger.valueOf(p))

  /**
   * Solves x^2 ≡ n (mod p) with the Tonelli–Shanks algorithm.
   *
   * Precondition: `p` is prime. For a composite `p` the search for a quadratic
   * non-residue may never terminate.
   *
   * @param n value whose square root is wanted; any sign or size (it is reduced mod p)
   * @param p prime modulus
   * @return both roots when n is a square mod p, otherwise a Solution with `exists = false`
   * @throws ArithmeticException if `p` is not positive (raised by `BigInteger.mod`)
   */
  private def ts(n: BigInteger, p: BigInteger): Solution = {
    val a = n.mod(p)
    val pMinusOne = p.subtract(ONE)

    def powModP(base: BigInteger, e: BigInteger): BigInteger = base.modPow(e, p)

    // Euler's criterion: x^((p-1)/2) is 1 for non-zero squares and p - 1 for non-squares.
    def legendre(x: BigInteger): BigInteger = powModP(x, pMinusOne.shiftRight(1))

    if (a.signum == 0) {
      // 0 is its own square root; Euler's criterion yields 0 (not 1) here, so handle it first.
      Solution(ZERO, ZERO, exists = true)
    } else if (!legendre(a).equals(ONE)) {
      Solution(ZERO, ZERO, exists = false)
    } else {
      // Factor p - 1 = q * 2^s with q odd.
      val s = pMinusOne.getLowestSetBit
      val q = pMinusOne.shiftRight(s)

      if (s == 1) {
        // p ≡ 3 (mod 4): the root is a^((p+1)/4) directly.
        val r1 = powModP(a, p.add(ONE).divide(FOUR))
        Solution(r1, p.subtract(r1), exists = true)
      } else {
        // First z >= 2 that is a quadratic non-residue (Euler's criterion gives p - 1).
        @tailrec
        def nonResidueFrom(z: BigInteger): BigInteger =
          if (legendre(z).equals(pMinusOne)) z else nonResidueFrom(z.add(ONE))

        // Least i such that zz^(2^i) ≡ 1 (mod p), capped at m - 1.
        @tailrec
        def leastSquaringExponent(zz: BigInteger, i: Int, m: Int): Int =
          if (zz.equals(ONE) || i >= m - 1) i
          else leastSquaringExponent(zz.multiply(zz).mod(p), i + 1, m)

        @tailrec
        def loop(r: BigInteger, c: BigInteger, t: BigInteger, m: Int): Solution =
          if (t.equals(ONE)) Solution(r, p.subtract(r), exists = true)
          else {
            val i = leastSquaringExponent(t, 0, m)
            // b = c^(2^(m - i - 1)) mod p, i.e. c squared (m - i - 1) times.
            val b = powModP(c, ONE.shiftLeft(math.max(m - i - 1, 0)))
            val cNext = b.multiply(b).mod(p)
            loop(r.multiply(b).mod(p), cNext, t.multiply(cNext).mod(p), i)
          }

        val z = nonResidueFrom(TWO)
        loop(powModP(a, q.add(ONE).shiftRight(1)), powModP(z, q), powModP(a, q), s)
      }
    }
  }

  /** Prints `n`, `p`, then either both roots or "No solution exists", one item per line. */
  private def printSolution(n: BigInteger, p: BigInteger, sol: Solution): Unit = {
    println(s"n = $n")
    println(s"p = $p")
    if (sol.exists) {
      println(s"root1 = ${sol.root1}")
      println(s"root2 = ${sol.root2}")
    } else {
      println("No solution exists")
    }
  }

  /** Solves and prints a fixed set of sample cases, separated by blank lines. */
  def main(args: Array[String]): Unit = {
    val pairs = List(
      (10L, 13L),
      (56L, 101L),
      (1030L, 10009L),
      (1032L, 10009L),
      (44402L, 100049L),
      (665820697L, 1000000009L),
      (881398088036L, 1000000000039L)
    )

    for ((n, p) <- pairs) {
      printSolution(BigInteger.valueOf(n), BigInteger.valueOf(p), ts(n, p))
      println()
    }

    val bn = new BigInteger("41660815127637347468140745042827704103445750172002")
    val bp = TEN.pow(50).add(BigInteger.valueOf(577))
    printSolution(bn, bp, ts(bn, bp))
  }
}
|