shamir/shamir.go

182 lines
3.6 KiB
Go
Raw Normal View History

2024-07-29 14:45:43 +02:00
package shamir
import (
"crypto/rand"
"fmt"
)
// polynomial is a polynomial over GF(2^8). coefficients[i] is the
// coefficient of x^i, so coefficients[0] is the constant term, i.e. the
// value of the polynomial at x = 0.
type polynomial struct {
	coefficients []byte
}
// newPolynomial returns a polynomial over GF(2^8) of the given degree whose
// constant term (its value at x = 0) is intercept. The other degree
// coefficients are filled with bytes from crypto/rand.
//
// It returns an error only if reading from the random source fails.
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
	// Compute the length as int: in uint8 arithmetic degree+1 wraps to 0 when
	// degree == 255, which would make the intercept assignment below panic.
	p := &polynomial{
		coefficients: make([]uint8, int(degree)+1),
	}
	p.coefficients[0] = intercept
	if _, err := rand.Read(p.coefficients[1:]); err != nil {
		return nil, fmt.Errorf("generating polynomial coefficients: %w", err)
	}
	return p, nil
}
// evaluate returns p(x) over GF(2^8), computed with Horner's rule: starting
// from the highest-degree coefficient, repeatedly multiply the running value
// by x and add the next lower coefficient.
func (p *polynomial) evaluate(x uint8) uint8 {
	coeffs := p.coefficients
	if x == 0 {
		// Every non-constant term vanishes at x = 0.
		return coeffs[0]
	}
	last := len(coeffs) - 1
	acc := coeffs[last]
	for k := last - 1; k >= 0; k-- {
		acc = add(mult(acc, x), coeffs[k])
	}
	return acc
}
// Split divides secret into N shares such that any T of them can be passed to
// Combine to reconstruct it.
//
// Every byte of the secret gets its own random polynomial of degree T-1 whose
// constant term is that byte. Share j (0-based) holds the evaluations of those
// polynomials at x = j+1, followed by x itself as the final byte, so each share
// is len(secret)+1 bytes long.
//
// The arguments must satisfy 2 <= T <= N <= 255 and secret must be non-empty;
// otherwise an error is returned. An error is also returned if the random
// source fails.
func Split(secret []byte, N, T int) ([][]byte, error) {
	switch {
	case N < T:
		return nil, fmt.Errorf("parts cannot be less than threshold")
	case N > 255:
		return nil, fmt.Errorf("parts cannot exceed 255")
	case T < 2:
		return nil, fmt.Errorf("threshold must be at least 2")
	case T > 255:
		return nil, fmt.Errorf("threshold cannot exceed 255")
	case len(secret) == 0:
		return nil, fmt.Errorf("cannot split an empty secret")
	}

	// Allocate every share up front and stamp its x-coordinate (1..N) into the
	// trailing byte; x = 0 is never used because p(0) is the secret itself.
	xPos := len(secret)
	shares := make([][]byte, N)
	for j := range shares {
		share := make([]byte, xPos+1)
		share[xPos] = uint8(j + 1)
		shares[j] = share
	}

	// One fresh polynomial per secret byte, evaluated at every share's x.
	for idx := range secret {
		poly, err := newPolynomial(secret[idx], uint8(T-1))
		if err != nil {
			return nil, err
		}
		for _, share := range shares {
			share[idx] = poly.evaluate(share[xPos])
		}
	}
	return shares, nil
}
// Combine reconstructs the secret from the provided shares.
//
// Each share is expected in the format produced by Split: the per-byte
// evaluations followed by the share's x-coordinate as the final byte. At least
// two shares are required, all of the same length (at least two bytes), with
// distinct, non-zero x-coordinates. Combine cannot detect that too few shares
// were supplied: with fewer shares than the threshold used by Split it returns
// a value that is, in general, not the original secret.
func Combine(shares [][]byte) ([]byte, error) {
	if len(shares) < 2 {
		return nil, fmt.Errorf("less than two shares cannot be used to reconstruct the secret")
	}
	shareLength := len(shares[0])
	if shareLength < 2 {
		return nil, fmt.Errorf("shares must be at least two bytes long")
	}
	for _, share := range shares {
		if len(share) != shareLength {
			return nil, fmt.Errorf("all shares must be the same length")
		}
	}

	// Collect the x-coordinates and validate them up front. A repeated x makes
	// a Lagrange denominator zero, which previously surfaced only as an opaque
	// "division by zero". An x of 0 is never produced by Split; accepting one
	// would make Combine return that share's bytes verbatim as the secret.
	xSamples := make([]uint8, len(shares))
	var seen [256]bool
	for i, share := range shares {
		x := share[shareLength-1]
		if x == 0 {
			return nil, fmt.Errorf("share %d has invalid x-coordinate 0", i)
		}
		if seen[x] {
			return nil, fmt.Errorf("duplicate share detected (x-coordinate %d)", x)
		}
		seen[x] = true
		xSamples[i] = x
	}

	// Reconstruct each secret byte independently by interpolating at x = 0.
	secret := make([]byte, shareLength-1)
	ySamples := make([]uint8, len(shares))
	for i := range secret {
		for j, share := range shares {
			ySamples[j] = share[i]
		}
		val, err := interpolatePolynomialSafe(xSamples, ySamples, 0)
		if err != nil {
			return nil, fmt.Errorf("interpolating secret byte %d: %w", i, err)
		}
		secret[i] = val
	}
	return secret, nil
}
// interpolatePolynomialSafe evaluates at x the polynomial through the points
// (xSamples[i], ySamples[i]) using Lagrange interpolation over GF(2^8).
// Subtraction in this field is XOR, so the differences (x - xj) and (xi - xj)
// are computed with add. The error returned by div is propagated; a zero
// denominator arises when two x-coordinates are equal.
func interpolatePolynomialSafe(xSamples, ySamples []uint8, x uint8) (uint8, error) {
	var sum uint8
	for i, xi := range xSamples {
		numerator, denominator := uint8(1), uint8(1)
		for j, xj := range xSamples {
			if j == i {
				continue
			}
			numerator = mult(numerator, add(x, xj))
			denominator = mult(denominator, add(xi, xj))
		}
		basis, err := div(numerator, denominator)
		if err != nil {
			return 0, err
		}
		sum = add(sum, mult(ySamples[i], basis))
	}
	return sum, nil
}
// Helper functions for arithmetic in GF(2^8).

// add returns a + b in GF(2^8). In a field of characteristic 2, addition and
// subtraction are the same operation: bitwise XOR.
func add(a, b uint8) uint8 {
	sum := b ^ a
	return sum
}
// mult returns a * b in GF(2^8) with the AES reduction polynomial
// x^8 + x^4 + x^3 + x + 1 (0x11B; 0x1B below is its low byte, applied after
// the shift has discarded bit 8).
//
// The operands are secret-derived in both Split and Combine, so this version
// always runs exactly eight iterations and uses all-ones/all-zeros masks
// instead of branching on operand bits. This avoids the data-dependent loop
// count and branches of the textbook shift-and-add loop. The results are
// identical.
func mult(a, b uint8) uint8 {
	var p uint8
	for i := 0; i < 8; i++ {
		// -(b&1) is 0xFF when the low bit of b is set and 0x00 otherwise.
		p ^= a & -(b & 1)
		// Multiply a by x; if the bit shifted out was set, reduce by 0x1B.
		carry := a >> 7
		a = (a << 1) ^ (0x1B & -carry)
		b >>= 1
	}
	return p
}
func div(a, b uint8) (uint8, error) {
if b == 0 {
return 0, fmt.Errorf("division by zero")
}
return mult(a, inverse(b)), nil
}
// inverse returns the multiplicative inverse of a in GF(2^8), or 0 when a is 0.
// Zero has no inverse; div rejects b == 0 before calling this.
//
// The multiplicative group of GF(2^8) has order 255, so a^255 = 1 and therefore
// a^254 = a^-1 for every non-zero a (and 0^254 = 0). Since
// 254 = 0b11111110 = 2+4+8+16+32+64+128, a^254 is the product
// a^2 * a^4 * ... * a^128. This costs 13 multiplications with no
// data-dependent branches, instead of a linear search of up to 255 candidates.
func inverse(a uint8) uint8 {
	sq := mult(a, a) // a^2
	out := sq
	for i := 0; i < 6; i++ {
		sq = mult(sq, sq) // a^4, a^8, ..., a^128
		out = mult(out, sq)
	}
	return out
}