export errors and refactor

This commit is contained in:
saepire 2024-07-31 23:13:20 +09:00
parent 0b6c46b3a6
commit 4975d23502
3 changed files with 85 additions and 75 deletions

View File

@ -1,6 +1,6 @@
## Shamir's Secret Sharing ## Shamir's Secret Sharing
This repository provides a simple implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares. This repository provides a minimal implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares.
- Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret. - Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret.
- Arithmetic operations in Galois Field (GF(`2^8`)). - Arithmetic operations in Galois Field (GF(`2^8`)).

132
shamir.go
View File

@ -2,21 +2,34 @@ package shamir
import ( import (
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
) )
var (
ErrPartsLessThanThreshold = errors.New("number of parts cannot be less than the threshold")
ErrPartsExceedLimit = errors.New("number of parts cannot exceed 255")
ErrThresholdTooSmall = errors.New("threshold must be at least 2")
ErrThresholdExceedLimit = errors.New("threshold cannot exceed 255")
ErrEmptySecret = errors.New("cannot split an empty secret")
ErrInsufficientShares = errors.New("less than two shares cannot be used to reconstruct the secret")
ErrSharesTooShort = errors.New("shares must be at least two bytes long")
ErrInconsistentShareLength = errors.New("all shares must be the same length")
ErrDivisionByZero = errors.New("division by zero")
)
type polynomial struct { type polynomial struct {
coefficients []uint8 coeffs []uint8
} }
func newPolynomial(intercept, degree uint8) (*polynomial, error) { func newPolynomial(intercept, degree uint8) (*polynomial, error) {
p := &polynomial{ p := &polynomial{
coefficients: make([]uint8, degree+1), coeffs: make([]uint8, degree+1),
} }
p.coefficients[0] = intercept p.coeffs[0] = intercept
if _, err := rand.Read(p.coefficients[1:]); err != nil { if _, err := rand.Read(p.coeffs[1:]); err != nil {
return nil, err return nil, fmt.Errorf("failed to generate random coefficients: %w", err)
} }
return p, nil return p, nil
@ -24,85 +37,82 @@ func newPolynomial(intercept, degree uint8) (*polynomial, error) {
func (p *polynomial) evaluate(x uint8) uint8 { func (p *polynomial) evaluate(x uint8) uint8 {
if x == 0 { if x == 0 {
return p.coefficients[0] return p.coeffs[0]
} }
out := p.coefficients[len(p.coefficients)-1] result := p.coeffs[len(p.coeffs)-1]
for i := len(p.coefficients) - 2; i >= 0; i-- { for i := len(p.coeffs) - 2; i >= 0; i-- {
out = add(mult(out, x), p.coefficients[i]) result = gfAdd(gfMult(result, x), p.coeffs[i])
} }
return out return result
} }
// Split divides the secret into parts shares with a threshold of minimum shares to reconstruct the secret. // Split divides the secret into n shares with a threshold t for reconstruction.
func Split(secret []byte, N, T int) ([][]byte, error) { func Split(secret []byte, n, t int) ([][]byte, error) {
if N < T { if n < t {
return nil, fmt.Errorf("parts cannot be less than threshold") return nil, ErrPartsLessThanThreshold
} }
if N > 255 { if n > 255 {
return nil, fmt.Errorf("parts cannot exceed 255") return nil, ErrPartsExceedLimit
} }
if T < 2 { if t < 2 {
return nil, fmt.Errorf("threshold must be at least 2") return nil, ErrThresholdTooSmall
} }
if T > 255 { if t > 255 {
return nil, fmt.Errorf("threshold cannot exceed 255") return nil, ErrThresholdExceedLimit
} }
if len(secret) == 0 { if len(secret) == 0 {
return nil, fmt.Errorf("cannot split an empty secret") return nil, ErrEmptySecret
} }
// Generate unique x-coordinates for each share xCoords := make([]uint8, n)
xCoordinates := make([]uint8, N) for i := 0; i < n; i++ {
for i := 0; i < N; i++ { xCoords[i] = uint8(i + 1)
xCoordinates[i] = uint8(i + 1)
} }
// Initialize shares with the secret length + 1 (for the x-coordinate) shares := make([][]byte, n)
shares := make([][]byte, N)
for i := range shares { for i := range shares {
shares[i] = make([]byte, len(secret)+1) shares[i] = make([]byte, len(secret)+1)
shares[i][len(secret)] = xCoordinates[i] shares[i][len(secret)] = xCoords[i]
} }
// Create a polynomial for each byte in the secret and evaluate it at each x-coordinate
for i, b := range secret { for i, b := range secret {
p, err := newPolynomial(b, uint8(T-1)) p, err := newPolynomial(b, uint8(t-1))
if err != nil { if err != nil {
return nil, err return nil, err
} }
for j := 0; j < N; j++ { for j := 0; j < n; j++ {
shares[j][i] = p.evaluate(xCoordinates[j]) shares[j][i] = p.evaluate(xCoords[j])
} }
} }
return shares, nil return shares, nil
} }
// Combine reconstructs the secret from the provided shares. // Combine reconstructs the secret from the shares.
func Combine(shares [][]byte) ([]byte, error) { func Combine(shares [][]byte) ([]byte, error) {
if len(shares) < 2 { if len(shares) < 2 {
return nil, fmt.Errorf("less than two shares cannot be used to reconstruct the secret") return nil, ErrInsufficientShares
} }
shareLength := len(shares[0]) shareLen := len(shares[0])
if shareLength < 2 { if shareLen < 2 {
return nil, fmt.Errorf("shares must be at least two bytes long") return nil, ErrSharesTooShort
} }
for _, share := range shares { for _, share := range shares {
if len(share) != shareLength { if len(share) != shareLen {
return nil, fmt.Errorf("all shares must be the same length") return nil, ErrInconsistentShareLength
} }
} }
secret := make([]byte, shareLength-1) secret := make([]byte, shareLen-1)
xSamples := make([]uint8, len(shares)) xSamples := make([]uint8, len(shares))
ySamples := make([]uint8, len(shares)) ySamples := make([]uint8, len(shares))
for i, share := range shares { for i, share := range shares {
xSamples[i] = share[shareLength-1] xSamples[i] = share[shareLen-1]
} }
for i := range secret { for i := range secret {
@ -110,7 +120,7 @@ func Combine(shares [][]byte) ([]byte, error) {
ySamples[j] = share[i] ySamples[j] = share[i]
} }
val, err := interpolatePolynomialSafe(xSamples, ySamples, 0) val, err := interpolatePolynomial(xSamples, ySamples, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,36 +131,36 @@ func Combine(shares [][]byte) ([]byte, error) {
return secret, nil return secret, nil
} }
func interpolatePolynomialSafe(xSamples, ySamples []uint8, x uint8) (uint8, error) { func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) (uint8, error) {
result := uint8(0) result := uint8(0)
for i := range xSamples { for i := range xSamples {
num, denom := uint8(1), uint8(1) num, denom := uint8(1), uint8(1)
for j := range xSamples { for j := range xSamples {
if i != j { if i != j {
num = mult(num, add(x, xSamples[j])) num = gfMult(num, gfAdd(x, xSamples[j]))
denom = mult(denom, add(xSamples[i], xSamples[j])) denom = gfMult(denom, gfAdd(xSamples[i], xSamples[j]))
} }
} }
term, err := div(num, denom) term, err := gfDiv(num, denom)
if err != nil { if err != nil {
return 0, err return 0, err
} }
result = add(result, mult(ySamples[i], term)) result = gfAdd(result, gfMult(ySamples[i], term))
} }
return result, nil return result, nil
} }
// Helper functions for arithmetic in GF(2^8) // Helper functions for arithmetic in GF(2^8)
func add(a, b uint8) uint8 { func gfAdd(a, b uint8) uint8 {
return a ^ b return a ^ b
} }
func mult(a, b uint8) uint8 { func gfMult(a, b uint8) uint8 {
var p uint8 var product uint8
for b > 0 { for b > 0 {
if b&1 == 1 { if b&1 == 1 {
p ^= a product ^= a
} }
if a&0x80 > 0 { if a&0x80 > 0 {
a = (a << 1) ^ 0x1B a = (a << 1) ^ 0x1B
@ -159,23 +169,23 @@ func mult(a, b uint8) uint8 {
} }
b >>= 1 b >>= 1
} }
return p return product
} }
func div(a, b uint8) (uint8, error) { func gfDiv(a, b uint8) (uint8, error) {
if b == 0 { if b == 0 {
return 0, fmt.Errorf("division by zero") return 0, ErrDivisionByZero
} }
return mult(a, inverse(b)), nil return gfMult(a, gfInverse(b)), nil
} }
func inverse(a uint8) uint8 { func gfInverse(a uint8) uint8 {
var b, c uint8 var inv uint8
for b = 1; b != 0; b++ { for b := uint8(1); b != 0; b++ {
if mult(a, b) == 1 { if gfMult(a, b) == 1 {
c = b inv = b
break break
} }
} }
return c return inv
} }

View File

@ -100,14 +100,14 @@ func TestFieldOperations(t *testing.T) {
a, b, expected uint8 a, b, expected uint8
op func(uint8, uint8) (uint8, error) op func(uint8, uint8) (uint8, error)
}{ }{
{16, 16, 0, func(a, b uint8) (uint8, error) { return add(a, b), nil }}, {16, 16, 0, func(a, b uint8) (uint8, error) { return gfAdd(a, b), nil }},
{3, 4, 7, func(a, b uint8) (uint8, error) { return add(a, b), nil }}, {3, 4, 7, func(a, b uint8) (uint8, error) { return gfAdd(a, b), nil }},
{3, 7, 9, func(a, b uint8) (uint8, error) { return mult(a, b), nil }}, {3, 7, 9, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
{3, 0, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }}, {3, 0, 0, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
{0, 3, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }}, {0, 3, 0, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
{0, 7, 0, div}, {0, 7, 0, gfDiv},
{3, 3, 1, div}, {3, 3, 1, gfDiv},
{6, 3, 2, div}, {6, 3, 2, gfDiv},
} }
for _, tt := range tests { for _, tt := range tests {
@ -127,8 +127,8 @@ func TestPolynomialCreationAndEvaluation(t *testing.T) {
t.Fatalf("NewPolynomial error: %v", err) t.Fatalf("NewPolynomial error: %v", err)
} }
if p.coefficients[0] != 42 { if p.coeffs[0] != 42 {
t.Fatalf("expected intercept 42, got %d", p.coefficients[0]) t.Fatalf("expected intercept 42, got %d", p.coeffs[0])
} }
if out := p.evaluate(0); out != 42 { if out := p.evaluate(0); out != 42 {
@ -136,7 +136,7 @@ func TestPolynomialCreationAndEvaluation(t *testing.T) {
} }
x := uint8(1) x := uint8(1)
expected := add(42, mult(x, p.coefficients[1])) expected := gfAdd(42, gfMult(x, p.coeffs[1]))
if out := p.evaluate(x); out != expected { if out := p.evaluate(x); out != expected {
t.Fatalf("expected %d, got %d", expected, out) t.Fatalf("expected %d, got %d", expected, out)
} }
@ -151,9 +151,9 @@ func TestPolynomialInterpolation(t *testing.T) {
xVals := []uint8{1, 2, 3} xVals := []uint8{1, 2, 3}
yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)} yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
out, err := interpolatePolynomialSafe(xVals, yVals, 0) out, err := interpolatePolynomial(xVals, yVals, 0)
if err != nil { if err != nil {
t.Fatalf("InterpolatePolynomialSafe error: %v", err) t.Fatalf("interpolatePolynomial error: %v", err)
} }
if out != uint8(i) { if out != uint8(i) {
t.Fatalf("expected %d, got %d", uint8(i), out) t.Fatalf("expected %d, got %d", uint8(i), out)