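// Package shamir implements Shamir's secret sharing over GF(2^8): Split
// divides a secret into n shares so that any t of them can reconstruct it
// with Combine.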
package shamir
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
)
|
|
|
|
// Errors returned by Split when its arguments are invalid.
var (
	// ErrPartsLessThanThreshold is returned when the number of shares n is
	// smaller than the threshold t.
	ErrPartsLessThanThreshold = errors.New("number of parts cannot be less than the threshold")

	// ErrPartsExceedLimit is returned when more than 255 shares are requested;
	// each share's x-coordinate is stored in a single nonzero byte.
	ErrPartsExceedLimit = errors.New("number of parts cannot exceed 255")

	// ErrThresholdTooSmall is returned when the threshold t is below 2.
	ErrThresholdTooSmall = errors.New("threshold must be at least 2")

	// ErrThresholdExceedLimit is returned when the threshold t is above 255.
	ErrThresholdExceedLimit = errors.New("threshold cannot exceed 255")

	// ErrEmptySecret is returned when the secret has no bytes.
	ErrEmptySecret = errors.New("cannot split an empty secret")
)

// Errors returned by Combine when the supplied shares are malformed.
var (
	// ErrInsufficientShares is returned when fewer than two shares are given.
	ErrInsufficientShares = errors.New("less than two shares cannot be used to reconstruct the secret")

	// ErrSharesTooShort is returned when a share cannot hold at least one
	// secret byte plus its trailing x-coordinate byte.
	ErrSharesTooShort = errors.New("shares must be at least two bytes long")

	// ErrInconsistentShareLength is returned when the shares differ in length.
	ErrInconsistentShareLength = errors.New("all shares must be the same length")
)

// ErrDivisionByZero is returned by gfDiv for a zero divisor. Combine reports
// it when two shares carry the same x-coordinate; compare with errors.Is.
var ErrDivisionByZero = errors.New("division by zero")
|
|
|
|
// polynomial is a polynomial over GF(2^8). coeffs[i] is the coefficient of
// x^i, so coeffs[0] is the constant term (the intercept).
type polynomial struct {
	coeffs []byte
}
|
|
|
|
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
|
|
p := &polynomial{
|
|
coeffs: make([]uint8, degree+1),
|
|
}
|
|
p.coeffs[0] = intercept
|
|
|
|
if _, err := rand.Read(p.coeffs[1:]); err != nil {
|
|
return nil, fmt.Errorf("failed to generate random coefficients: %w", err)
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
func (p *polynomial) evaluate(x uint8) uint8 {
|
|
result := p.coeffs[len(p.coeffs)-1]
|
|
for i := len(p.coeffs) - 2; i >= 0; i-- {
|
|
result = gfAdd(gfMult(result, x), p.coeffs[i])
|
|
}
|
|
return result
|
|
}
|
|
|
|
// Split divides the secret into n shares with a threshold t for reconstruction.
|
|
func Split(secret []byte, n, t int) ([][]byte, error) {
|
|
if n < t {
|
|
return nil, ErrPartsLessThanThreshold
|
|
}
|
|
if n > 255 {
|
|
return nil, ErrPartsExceedLimit
|
|
}
|
|
if t < 2 {
|
|
return nil, ErrThresholdTooSmall
|
|
}
|
|
if t > 255 {
|
|
return nil, ErrThresholdExceedLimit
|
|
}
|
|
if len(secret) == 0 {
|
|
return nil, ErrEmptySecret
|
|
}
|
|
|
|
shares := make([][]byte, n)
|
|
for i := range shares {
|
|
shares[i] = make([]byte, len(secret)+1)
|
|
shares[i][len(secret)] = uint8(i + 1)
|
|
}
|
|
|
|
for i, b := range secret {
|
|
p, err := newPolynomial(b, uint8(t-1))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for j := range shares {
|
|
shares[j][i] = p.evaluate(uint8(j + 1))
|
|
}
|
|
}
|
|
|
|
return shares, nil
|
|
}
|
|
|
|
// Combine reconstructs the secret from the shares.
|
|
func Combine(shares [][]byte) ([]byte, error) {
|
|
if len(shares) < 2 {
|
|
return nil, ErrInsufficientShares
|
|
}
|
|
|
|
shareLen := len(shares[0])
|
|
if shareLen < 2 {
|
|
return nil, ErrSharesTooShort
|
|
}
|
|
|
|
for _, share := range shares {
|
|
if len(share) != shareLen {
|
|
return nil, ErrInconsistentShareLength
|
|
}
|
|
}
|
|
|
|
secret := make([]byte, shareLen-1)
|
|
xSamples := make([]uint8, len(shares))
|
|
ySamples := make([]uint8, len(shares))
|
|
|
|
for i, share := range shares {
|
|
xSamples[i] = share[shareLen-1]
|
|
}
|
|
|
|
for i := range secret {
|
|
for j, share := range shares {
|
|
ySamples[j] = share[i]
|
|
}
|
|
|
|
val, err := interpolatePolynomial(xSamples, ySamples, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
secret[i] = val
|
|
}
|
|
|
|
return secret, nil
|
|
}
|
|
|
|
func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) (uint8, error) {
|
|
result := uint8(0)
|
|
for i := range xSamples {
|
|
num, denom := uint8(1), uint8(1)
|
|
for j := range xSamples {
|
|
if i != j {
|
|
num = gfMult(num, gfAdd(x, xSamples[j]))
|
|
denom = gfMult(denom, gfAdd(xSamples[i], xSamples[j]))
|
|
}
|
|
}
|
|
term, err := gfDiv(num, denom)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
result = gfAdd(result, gfMult(ySamples[i], term))
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
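// Helper functions for arithmetic in GF(2^8), the finite field with 256
// elements, reduced modulo the AES polynomial x^8 + x^4 + x^3 + x + 1 (0x11B).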
|
|
|
|
// gfAdd returns the sum of a and b in GF(2^8). The field has characteristic
// 2, so addition and subtraction are both bitwise XOR.
func gfAdd(a, b uint8) uint8 {
	sum := b ^ a
	return sum
}
|
|
|
|
// gfMult returns the product of a and b in GF(2^8), reducing modulo the AES
// polynomial x^8 + x^4 + x^3 + x + 1 (0x11B).
//
// The operands are often secret-derived (polynomial coefficients during Split,
// share bytes during Combine). The loop therefore always runs exactly eight
// rounds and selects terms with bit masks instead of branches, so the source
// has no loop bound or branch that depends on a or b. An early-exit,
// branching loop leaks operand bits through timing.
func gfMult(a, b uint8) uint8 {
	var product uint8
	for round := 0; round < 8; round++ {
		// -(b & 1) is 0xFF when the low bit of b is set and 0x00 otherwise,
		// so a is added to the product only in that case.
		product ^= -(b & 1) & a

		// -(a >> 7) is 0xFF when the high bit of a is set; the left shift
		// then overflows past x^7 and is reduced by XOR with 0x1B.
		a = (a << 1) ^ (-(a >> 7) & 0x1B)

		b >>= 1
	}
	return product
}
|
|
|
|
func gfDiv(a, b uint8) (uint8, error) {
|
|
if b == 0 {
|
|
return 0, ErrDivisionByZero
|
|
}
|
|
return gfMult(a, gfInverse(b)), nil
|
|
}
|
|
|
|
func gfInverse(a uint8) uint8 {
|
|
for b := uint8(1); b != 0; b++ {
|
|
if gfMult(a, b) == 1 {
|
|
return b
|
|
}
|
|
}
|
|
return 0 // This should never be reached if the input is non-zero
|
|
}
|