shamir/shamir.go

192 lines
3.9 KiB
Go
Raw Normal View History

2024-07-29 14:45:43 +02:00
package shamir
import (
"crypto/rand"
2024-07-31 16:13:20 +02:00
"errors"
2024-07-29 14:45:43 +02:00
"fmt"
)
2024-07-31 16:13:20 +02:00
var (
	// ErrPartsLessThanThreshold is returned by Split when the number of parts
	// is smaller than the reconstruction threshold.
	ErrPartsLessThanThreshold = errors.New("number of parts cannot be less than the threshold")

	// ErrPartsExceedLimit is returned by Split when more than 255 parts are
	// requested; each share stores its x-coordinate in a single byte.
	ErrPartsExceedLimit = errors.New("number of parts cannot exceed 255")

	// ErrThresholdTooSmall is returned by Split when the threshold is below 2.
	ErrThresholdTooSmall = errors.New("threshold must be at least 2")

	// ErrThresholdExceedLimit reports a threshold above 255.
	ErrThresholdExceedLimit = errors.New("threshold cannot exceed 255")

	// ErrEmptySecret is returned by Split when the secret has no bytes.
	ErrEmptySecret = errors.New("cannot split an empty secret")

	// ErrInsufficientShares is returned by Combine when given fewer than two
	// shares.
	ErrInsufficientShares = errors.New("less than two shares cannot be used to reconstruct the secret")

	// ErrSharesTooShort is returned by Combine when shares are shorter than two
	// bytes (at least one data byte plus the trailing x-coordinate byte).
	ErrSharesTooShort = errors.New("shares must be at least two bytes long")

	// ErrInconsistentShareLength is returned by Combine when the shares do not
	// all have the same length.
	ErrInconsistentShareLength = errors.New("all shares must be the same length")

	// ErrDivisionByZero is returned by gfDiv for a zero divisor; it surfaces
	// from Combine when two shares carry the same x-coordinate.
	ErrDivisionByZero = errors.New("division by zero")
)
2024-07-29 14:45:43 +02:00
type polynomial struct {
2024-07-31 16:13:20 +02:00
coeffs []uint8
2024-07-29 14:45:43 +02:00
}
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
p := &polynomial{
2024-07-31 16:13:20 +02:00
coeffs: make([]uint8, degree+1),
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
p.coeffs[0] = intercept
2024-07-29 14:45:43 +02:00
2024-07-31 16:13:20 +02:00
if _, err := rand.Read(p.coeffs[1:]); err != nil {
return nil, fmt.Errorf("failed to generate random coefficients: %w", err)
2024-07-29 14:45:43 +02:00
}
return p, nil
}
func (p *polynomial) evaluate(x uint8) uint8 {
if x == 0 {
2024-07-31 16:13:20 +02:00
return p.coeffs[0]
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
result := p.coeffs[len(p.coeffs)-1]
for i := len(p.coeffs) - 2; i >= 0; i-- {
result = gfAdd(gfMult(result, x), p.coeffs[i])
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
return result
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
// Split divides the secret into n shares with a threshold t for reconstruction.
func Split(secret []byte, n, t int) ([][]byte, error) {
if n < t {
return nil, ErrPartsLessThanThreshold
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
if n > 255 {
return nil, ErrPartsExceedLimit
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
if t < 2 {
return nil, ErrThresholdTooSmall
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
if t > 255 {
return nil, ErrThresholdExceedLimit
2024-07-29 14:45:43 +02:00
}
if len(secret) == 0 {
2024-07-31 16:13:20 +02:00
return nil, ErrEmptySecret
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
xCoords := make([]uint8, n)
for i := 0; i < n; i++ {
xCoords[i] = uint8(i + 1)
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
shares := make([][]byte, n)
2024-07-29 14:45:43 +02:00
for i := range shares {
shares[i] = make([]byte, len(secret)+1)
2024-07-31 16:13:20 +02:00
shares[i][len(secret)] = xCoords[i]
2024-07-29 14:45:43 +02:00
}
for i, b := range secret {
2024-07-31 16:13:20 +02:00
p, err := newPolynomial(b, uint8(t-1))
2024-07-29 14:45:43 +02:00
if err != nil {
return nil, err
}
2024-07-31 16:13:20 +02:00
for j := 0; j < n; j++ {
shares[j][i] = p.evaluate(xCoords[j])
2024-07-29 14:45:43 +02:00
}
}
return shares, nil
}
2024-07-31 16:13:20 +02:00
// Combine reconstructs the secret from the shares.
2024-07-29 14:45:43 +02:00
func Combine(shares [][]byte) ([]byte, error) {
if len(shares) < 2 {
2024-07-31 16:13:20 +02:00
return nil, ErrInsufficientShares
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
shareLen := len(shares[0])
if shareLen < 2 {
return nil, ErrSharesTooShort
2024-07-29 14:45:43 +02:00
}
for _, share := range shares {
2024-07-31 16:13:20 +02:00
if len(share) != shareLen {
return nil, ErrInconsistentShareLength
2024-07-29 14:45:43 +02:00
}
}
2024-07-31 16:13:20 +02:00
secret := make([]byte, shareLen-1)
2024-07-29 14:45:43 +02:00
xSamples := make([]uint8, len(shares))
ySamples := make([]uint8, len(shares))
for i, share := range shares {
2024-07-31 16:13:20 +02:00
xSamples[i] = share[shareLen-1]
2024-07-29 14:45:43 +02:00
}
for i := range secret {
for j, share := range shares {
ySamples[j] = share[i]
}
2024-07-31 16:13:20 +02:00
val, err := interpolatePolynomial(xSamples, ySamples, 0)
2024-07-29 14:45:43 +02:00
if err != nil {
return nil, err
}
secret[i] = val
}
return secret, nil
}
2024-07-31 16:13:20 +02:00
func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) (uint8, error) {
2024-07-29 14:45:43 +02:00
result := uint8(0)
for i := range xSamples {
num, denom := uint8(1), uint8(1)
for j := range xSamples {
if i != j {
2024-07-31 16:13:20 +02:00
num = gfMult(num, gfAdd(x, xSamples[j]))
denom = gfMult(denom, gfAdd(xSamples[i], xSamples[j]))
2024-07-29 14:45:43 +02:00
}
}
2024-07-31 16:13:20 +02:00
term, err := gfDiv(num, denom)
2024-07-29 14:45:43 +02:00
if err != nil {
return 0, err
}
2024-07-31 16:13:20 +02:00
result = gfAdd(result, gfMult(ySamples[i], term))
2024-07-29 14:45:43 +02:00
}
return result, nil
}
// Helper functions for arithmetic in GF(2^8)
2024-07-31 16:13:20 +02:00
// gfAdd returns a + b in GF(2^8). In a field of characteristic 2, addition
// and subtraction are the same operation: bitwise XOR.
func gfAdd(a, b uint8) uint8 {
	sum := a ^ b
	return sum
}
2024-07-31 16:13:20 +02:00
// gfMult returns a * b in GF(2^8) with the reduction polynomial
// x^8 + x^4 + x^3 + x + 1 (0x11B), the same field AES uses.
//
// The loop always runs exactly 8 times and selects values with masks instead
// of branches. Its control flow therefore does not depend on the operands,
// which can be secret bytes during Split and Combine.
func gfMult(a, b uint8) uint8 {
	var product uint8
	for i := 0; i < 8; i++ {
		// -(b & 1) is 0xFF when the low bit of b is set and 0x00 otherwise.
		product ^= a & -(b & 1)
		// Multiply a by x. When the high bit shifts out, reduce by XOR-ing in
		// 0x1B; -(a >> 7) is 0xFF exactly in that case.
		a = (a << 1) ^ (0x1B & -(a >> 7))
		b >>= 1
	}
	return product
}
2024-07-31 16:13:20 +02:00
func gfDiv(a, b uint8) (uint8, error) {
2024-07-29 14:45:43 +02:00
if b == 0 {
2024-07-31 16:13:20 +02:00
return 0, ErrDivisionByZero
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
return gfMult(a, gfInverse(b)), nil
2024-07-29 14:45:43 +02:00
}
2024-07-31 16:13:20 +02:00
func gfInverse(a uint8) uint8 {
var inv uint8
for b := uint8(1); b != 0; b++ {
if gfMult(a, b) == 1 {
inv = b
2024-07-29 14:45:43 +02:00
break
}
}
2024-07-31 16:13:20 +02:00
return inv
2024-07-29 14:45:43 +02:00
}