shamir/shamir.go
2024-07-31 23:29:17 +09:00

181 lines
3.8 KiB
Go

package shamir
import (
"crypto/rand"
"errors"
"fmt"
)
// Sentinel errors returned by this package. They are plain errors.New values,
// so callers can match them with errors.Is.
var (
// ErrPartsLessThanThreshold is returned by Split when n < t.
ErrPartsLessThanThreshold = errors.New("number of parts cannot be less than the threshold")
// ErrPartsExceedLimit is returned by Split when n > 255.
ErrPartsExceedLimit = errors.New("number of parts cannot exceed 255")
// ErrThresholdTooSmall is returned by Split when t < 2.
ErrThresholdTooSmall = errors.New("threshold must be at least 2")
// ErrThresholdExceedLimit is returned by Split when t > 255. Split checks
// n < t and n > 255 first, so those errors take precedence for such inputs.
ErrThresholdExceedLimit = errors.New("threshold cannot exceed 255")
// ErrEmptySecret is returned by Split when the secret has length zero.
ErrEmptySecret = errors.New("cannot split an empty secret")
// ErrInsufficientShares is returned by Combine when fewer than two shares
// are supplied.
ErrInsufficientShares = errors.New("less than two shares cannot be used to reconstruct the secret")
// ErrSharesTooShort is returned by Combine when the first share is shorter
// than two bytes.
ErrSharesTooShort = errors.New("shares must be at least two bytes long")
// ErrInconsistentShareLength is returned by Combine when the shares do not
// all have the same length as the first one.
ErrInconsistentShareLength = errors.New("all shares must be the same length")
// ErrDivisionByZero is returned by gfDiv when the divisor is zero. It can
// reach callers of Combine: two shares with the same x-coordinate make the
// interpolation denominator zero.
ErrDivisionByZero = errors.New("division by zero")
)
// polynomial is a polynomial whose coefficients are single bytes.
// coeffs[i] is the coefficient of x^i, so coeffs[0] is the constant term.
type polynomial struct {
	coeffs []uint8
}

// newPolynomial returns a polynomial of the given degree whose constant term
// is intercept and whose remaining coefficients are filled from crypto/rand.
//
// It returns an error, wrapping the underlying cause, when the random source
// fails.
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
	// Size the slice in int arithmetic: degree+1 evaluated as uint8 wraps to 0
	// for degree == 255, which would make the coeffs[0] store below panic.
	coeffs := make([]uint8, int(degree)+1)
	coeffs[0] = intercept

	// Only the non-constant coefficients are random; for degree 0 this slice
	// is empty and nothing is read.
	if _, err := rand.Read(coeffs[1:]); err != nil {
		return nil, fmt.Errorf("failed to generate random coefficients: %w", err)
	}
	return &polynomial{coeffs: coeffs}, nil
}
// evaluate returns the value of the polynomial at x, using gfAdd and gfMult
// for the arithmetic. It applies Horner's rule: start from the highest-order
// coefficient and fold in one lower coefficient per step.
func (p *polynomial) evaluate(x uint8) uint8 {
	top := len(p.coeffs) - 1
	acc := p.coeffs[top]
	for k := top; k > 0; k-- {
		acc = gfAdd(gfMult(acc, x), p.coeffs[k-1])
	}
	return acc
}
// Split divides secret into n shares using threshold t.
//
// For every byte of the secret a fresh random polynomial of degree t-1 is
// created with that byte as its constant term, and share k (1-based) stores
// the polynomial's value at x = k. Each returned share is len(secret)+1 bytes
// long: the evaluated bytes followed by a final byte holding the share's
// x-coordinate.
//
// Arguments are validated in a fixed order and the first failing check decides
// the error: ErrPartsLessThanThreshold, ErrPartsExceedLimit,
// ErrThresholdTooSmall, ErrThresholdExceedLimit, ErrEmptySecret. An error from
// newPolynomial is returned unchanged.
func Split(secret []byte, n, t int) ([][]byte, error) {
	// Cases are tried top to bottom, which fixes the precedence of the errors.
	switch {
	case n < t:
		return nil, ErrPartsLessThanThreshold
	case n > 255:
		return nil, ErrPartsExceedLimit
	case t < 2:
		return nil, ErrThresholdTooSmall
	case t > 255:
		return nil, ErrThresholdExceedLimit
	case len(secret) == 0:
		return nil, ErrEmptySecret
	}

	size := len(secret)
	degree := uint8(t - 1)

	// Allocate the shares and tag each with its x-coordinate (1..n) in the
	// trailing byte.
	out := make([][]byte, 0, n)
	for x := 1; x <= n; x++ {
		share := make([]byte, size+1)
		share[size] = uint8(x)
		out = append(out, share)
	}

	// One independent polynomial per secret byte, sampled at every share's x.
	for pos := 0; pos < size; pos++ {
		poly, err := newPolynomial(secret[pos], degree)
		if err != nil {
			return nil, err
		}
		for _, share := range out {
			share[pos] = poly.evaluate(share[size])
		}
	}
	return out, nil
}
// Combine reconstructs a secret from shares.
//
// Every share is read as a run of y-values followed by one trailing byte
// holding the share's x-coordinate. For each byte position the y-values of all
// shares are interpolated at x = 0 to produce the corresponding secret byte,
// so the result is one byte shorter than the shares.
//
// It returns ErrInsufficientShares for fewer than two shares,
// ErrSharesTooShort when the first share is shorter than two bytes,
// ErrInconsistentShareLength when the share lengths differ, and any error from
// interpolatePolynomial unchanged.
func Combine(shares [][]byte) ([]byte, error) {
	count := len(shares)
	if count < 2 {
		return nil, ErrInsufficientShares
	}

	width := len(shares[0])
	if width < 2 {
		return nil, ErrSharesTooShort
	}
	for k := range shares {
		if len(shares[k]) != width {
			return nil, ErrInconsistentShareLength
		}
	}

	// The x-coordinates live in the last byte and are the same for every
	// byte position, so collect them once.
	last := width - 1
	xs := make([]uint8, count)
	for k := range shares {
		xs[k] = shares[k][last]
	}

	// ys is scratch space reused for every byte position.
	ys := make([]uint8, count)
	out := make([]byte, last)
	for pos := range out {
		for k := range shares {
			ys[k] = shares[k][pos]
		}
		b, err := interpolatePolynomial(xs, ys, 0)
		if err != nil {
			return nil, err
		}
		out[pos] = b
	}
	return out, nil
}
// interpolatePolynomial evaluates, at x, the polynomial passing through the
// points (xSamples[i], ySamples[i]), using Lagrange interpolation with the
// gf* helpers for arithmetic.
//
// ySamples is indexed with the indices of xSamples. An error from gfDiv
// (a zero denominator, which happens when two x-samples are equal) is
// returned unchanged together with a zero value.
func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) (uint8, error) {
	var sum uint8
	for i, xi := range xSamples {
		// Build the i-th Lagrange basis value as numerator/denominator.
		numerator, denominator := uint8(1), uint8(1)
		for j, xj := range xSamples {
			if j == i {
				continue
			}
			// gfAdd doubles as subtraction here: both are XOR.
			numerator = gfMult(numerator, gfAdd(x, xj))
			denominator = gfMult(denominator, gfAdd(xi, xj))
		}

		basis, err := gfDiv(numerator, denominator)
		if err != nil {
			return 0, err
		}
		sum = gfAdd(sum, gfMult(ySamples[i], basis))
	}
	return sum, nil
}
// Helper functions for arithmetic in GF(2^8)

// gfAdd returns the sum of a and b, computed as their bitwise XOR.
func gfAdd(a, b uint8) uint8 {
	sum := a
	sum ^= b
	return sum
}
// gfMult returns the product of a and b in GF(2^8), reducing with the constant
// 0x1B (the low byte of the polynomial x^8 + x^4 + x^3 + x + 1).
//
// It is shift-and-add multiplication: for every set bit of b the current
// multiple of a is XORed into the product, and a is doubled (with reduction
// when its top bit falls off) after every step.
//
// The operands are bytes derived from a secret, so the loop always runs eight
// rounds and uses masks instead of conditionals; the source therefore contains
// no branch or loop bound that depends on the operand values.
func gfMult(a, b uint8) uint8 {
	var product uint8
	for i := 0; i < 8; i++ {
		// -(b & 1) is 0xFF when the low bit of b is set and 0x00 otherwise
		// (unsigned negation wraps), so a is added only for set bits.
		product ^= a & -(b & 1)

		// Double a; when its top bit is set the shifted-out term is folded
		// back in by XORing the reduction constant.
		carry := -(a >> 7)
		a = (a << 1) ^ (0x1B & carry)

		b >>= 1
	}
	return product
}
func gfDiv(a, b uint8) (uint8, error) {
if b == 0 {
return 0, ErrDivisionByZero
}
return gfMult(a, gfInverse(b)), nil
}
// gfInverse returns the multiplicative inverse of a in GF(2^8), i.e. the value
// b for which gfMult(a, b) == 1. For a == 0, which has no inverse, it
// returns 0.
//
// The non-zero field elements form a multiplicative group of order 255, so
// a^255 == 1 and therefore a^254 is the inverse of a; 0^254 is 0, which
// matches the zero-input result. Computing the power takes a fixed 13
// multiplications, instead of searching up to 255 candidates and stopping at
// an operand-dependent point.
func gfInverse(a uint8) uint8 {
	// Each round maps a^e to a^(2e+1): 1 -> 3 -> 7 -> 15 -> 31 -> 63 -> 127.
	r := a
	for i := 0; i < 6; i++ {
		r = gfMult(gfMult(r, r), a)
	}
	// One last squaring gives a^254.
	return gfMult(r, r)
}