export errors and refactor
This commit is contained in:
parent
0b6c46b3a6
commit
4975d23502
@ -1,6 +1,6 @@
|
|||||||
## Shamir's Secret Sharing
|
## Shamir's Secret Sharing
|
||||||
|
|
||||||
This repository provides a simple implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares.
|
This repository provides a minimal implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares.
|
||||||
|
|
||||||
- Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret.
|
- Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret.
|
||||||
- Arithmetic operations in Galois Field (GF(`2^8`)).
|
- Arithmetic operations in Galois Field (GF(`2^8`)).
|
||||||
|
132
shamir.go
132
shamir.go
@ -2,21 +2,34 @@ package shamir
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPartsLessThanThreshold = errors.New("number of parts cannot be less than the threshold")
|
||||||
|
ErrPartsExceedLimit = errors.New("number of parts cannot exceed 255")
|
||||||
|
ErrThresholdTooSmall = errors.New("threshold must be at least 2")
|
||||||
|
ErrThresholdExceedLimit = errors.New("threshold cannot exceed 255")
|
||||||
|
ErrEmptySecret = errors.New("cannot split an empty secret")
|
||||||
|
ErrInsufficientShares = errors.New("less than two shares cannot be used to reconstruct the secret")
|
||||||
|
ErrSharesTooShort = errors.New("shares must be at least two bytes long")
|
||||||
|
ErrInconsistentShareLength = errors.New("all shares must be the same length")
|
||||||
|
ErrDivisionByZero = errors.New("division by zero")
|
||||||
|
)
|
||||||
|
|
||||||
type polynomial struct {
|
type polynomial struct {
|
||||||
coefficients []uint8
|
coeffs []uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
|
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
|
||||||
p := &polynomial{
|
p := &polynomial{
|
||||||
coefficients: make([]uint8, degree+1),
|
coeffs: make([]uint8, degree+1),
|
||||||
}
|
}
|
||||||
p.coefficients[0] = intercept
|
p.coeffs[0] = intercept
|
||||||
|
|
||||||
if _, err := rand.Read(p.coefficients[1:]); err != nil {
|
if _, err := rand.Read(p.coeffs[1:]); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to generate random coefficients: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
@ -24,85 +37,82 @@ func newPolynomial(intercept, degree uint8) (*polynomial, error) {
|
|||||||
|
|
||||||
func (p *polynomial) evaluate(x uint8) uint8 {
|
func (p *polynomial) evaluate(x uint8) uint8 {
|
||||||
if x == 0 {
|
if x == 0 {
|
||||||
return p.coefficients[0]
|
return p.coeffs[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
out := p.coefficients[len(p.coefficients)-1]
|
result := p.coeffs[len(p.coeffs)-1]
|
||||||
for i := len(p.coefficients) - 2; i >= 0; i-- {
|
for i := len(p.coeffs) - 2; i >= 0; i-- {
|
||||||
out = add(mult(out, x), p.coefficients[i])
|
result = gfAdd(gfMult(result, x), p.coeffs[i])
|
||||||
}
|
}
|
||||||
return out
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split divides the secret into parts shares with a threshold of minimum shares to reconstruct the secret.
|
// Split divides the secret into n shares with a threshold t for reconstruction.
|
||||||
func Split(secret []byte, N, T int) ([][]byte, error) {
|
func Split(secret []byte, n, t int) ([][]byte, error) {
|
||||||
if N < T {
|
if n < t {
|
||||||
return nil, fmt.Errorf("parts cannot be less than threshold")
|
return nil, ErrPartsLessThanThreshold
|
||||||
}
|
}
|
||||||
if N > 255 {
|
if n > 255 {
|
||||||
return nil, fmt.Errorf("parts cannot exceed 255")
|
return nil, ErrPartsExceedLimit
|
||||||
}
|
}
|
||||||
if T < 2 {
|
if t < 2 {
|
||||||
return nil, fmt.Errorf("threshold must be at least 2")
|
return nil, ErrThresholdTooSmall
|
||||||
}
|
}
|
||||||
if T > 255 {
|
if t > 255 {
|
||||||
return nil, fmt.Errorf("threshold cannot exceed 255")
|
return nil, ErrThresholdExceedLimit
|
||||||
}
|
}
|
||||||
if len(secret) == 0 {
|
if len(secret) == 0 {
|
||||||
return nil, fmt.Errorf("cannot split an empty secret")
|
return nil, ErrEmptySecret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate unique x-coordinates for each share
|
xCoords := make([]uint8, n)
|
||||||
xCoordinates := make([]uint8, N)
|
for i := 0; i < n; i++ {
|
||||||
for i := 0; i < N; i++ {
|
xCoords[i] = uint8(i + 1)
|
||||||
xCoordinates[i] = uint8(i + 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize shares with the secret length + 1 (for the x-coordinate)
|
shares := make([][]byte, n)
|
||||||
shares := make([][]byte, N)
|
|
||||||
for i := range shares {
|
for i := range shares {
|
||||||
shares[i] = make([]byte, len(secret)+1)
|
shares[i] = make([]byte, len(secret)+1)
|
||||||
shares[i][len(secret)] = xCoordinates[i]
|
shares[i][len(secret)] = xCoords[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a polynomial for each byte in the secret and evaluate it at each x-coordinate
|
|
||||||
for i, b := range secret {
|
for i, b := range secret {
|
||||||
p, err := newPolynomial(b, uint8(T-1))
|
p, err := newPolynomial(b, uint8(t-1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for j := 0; j < N; j++ {
|
for j := 0; j < n; j++ {
|
||||||
shares[j][i] = p.evaluate(xCoordinates[j])
|
shares[j][i] = p.evaluate(xCoords[j])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return shares, nil
|
return shares, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine reconstructs the secret from the provided shares.
|
// Combine reconstructs the secret from the shares.
|
||||||
func Combine(shares [][]byte) ([]byte, error) {
|
func Combine(shares [][]byte) ([]byte, error) {
|
||||||
if len(shares) < 2 {
|
if len(shares) < 2 {
|
||||||
return nil, fmt.Errorf("less than two shares cannot be used to reconstruct the secret")
|
return nil, ErrInsufficientShares
|
||||||
}
|
}
|
||||||
|
|
||||||
shareLength := len(shares[0])
|
shareLen := len(shares[0])
|
||||||
if shareLength < 2 {
|
if shareLen < 2 {
|
||||||
return nil, fmt.Errorf("shares must be at least two bytes long")
|
return nil, ErrSharesTooShort
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, share := range shares {
|
for _, share := range shares {
|
||||||
if len(share) != shareLength {
|
if len(share) != shareLen {
|
||||||
return nil, fmt.Errorf("all shares must be the same length")
|
return nil, ErrInconsistentShareLength
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
secret := make([]byte, shareLength-1)
|
secret := make([]byte, shareLen-1)
|
||||||
xSamples := make([]uint8, len(shares))
|
xSamples := make([]uint8, len(shares))
|
||||||
ySamples := make([]uint8, len(shares))
|
ySamples := make([]uint8, len(shares))
|
||||||
|
|
||||||
for i, share := range shares {
|
for i, share := range shares {
|
||||||
xSamples[i] = share[shareLength-1]
|
xSamples[i] = share[shareLen-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range secret {
|
for i := range secret {
|
||||||
@ -110,7 +120,7 @@ func Combine(shares [][]byte) ([]byte, error) {
|
|||||||
ySamples[j] = share[i]
|
ySamples[j] = share[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err := interpolatePolynomialSafe(xSamples, ySamples, 0)
|
val, err := interpolatePolynomial(xSamples, ySamples, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -121,36 +131,36 @@ func Combine(shares [][]byte) ([]byte, error) {
|
|||||||
return secret, nil
|
return secret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func interpolatePolynomialSafe(xSamples, ySamples []uint8, x uint8) (uint8, error) {
|
func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) (uint8, error) {
|
||||||
result := uint8(0)
|
result := uint8(0)
|
||||||
for i := range xSamples {
|
for i := range xSamples {
|
||||||
num, denom := uint8(1), uint8(1)
|
num, denom := uint8(1), uint8(1)
|
||||||
for j := range xSamples {
|
for j := range xSamples {
|
||||||
if i != j {
|
if i != j {
|
||||||
num = mult(num, add(x, xSamples[j]))
|
num = gfMult(num, gfAdd(x, xSamples[j]))
|
||||||
denom = mult(denom, add(xSamples[i], xSamples[j]))
|
denom = gfMult(denom, gfAdd(xSamples[i], xSamples[j]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
term, err := div(num, denom)
|
term, err := gfDiv(num, denom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
result = add(result, mult(ySamples[i], term))
|
result = gfAdd(result, gfMult(ySamples[i], term))
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper functions for arithmetic in GF(2^8)
|
// Helper functions for arithmetic in GF(2^8)
|
||||||
|
|
||||||
func add(a, b uint8) uint8 {
|
func gfAdd(a, b uint8) uint8 {
|
||||||
return a ^ b
|
return a ^ b
|
||||||
}
|
}
|
||||||
|
|
||||||
func mult(a, b uint8) uint8 {
|
func gfMult(a, b uint8) uint8 {
|
||||||
var p uint8
|
var product uint8
|
||||||
for b > 0 {
|
for b > 0 {
|
||||||
if b&1 == 1 {
|
if b&1 == 1 {
|
||||||
p ^= a
|
product ^= a
|
||||||
}
|
}
|
||||||
if a&0x80 > 0 {
|
if a&0x80 > 0 {
|
||||||
a = (a << 1) ^ 0x1B
|
a = (a << 1) ^ 0x1B
|
||||||
@ -159,23 +169,23 @@ func mult(a, b uint8) uint8 {
|
|||||||
}
|
}
|
||||||
b >>= 1
|
b >>= 1
|
||||||
}
|
}
|
||||||
return p
|
return product
|
||||||
}
|
}
|
||||||
|
|
||||||
func div(a, b uint8) (uint8, error) {
|
func gfDiv(a, b uint8) (uint8, error) {
|
||||||
if b == 0 {
|
if b == 0 {
|
||||||
return 0, fmt.Errorf("division by zero")
|
return 0, ErrDivisionByZero
|
||||||
}
|
}
|
||||||
return mult(a, inverse(b)), nil
|
return gfMult(a, gfInverse(b)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func inverse(a uint8) uint8 {
|
func gfInverse(a uint8) uint8 {
|
||||||
var b, c uint8
|
var inv uint8
|
||||||
for b = 1; b != 0; b++ {
|
for b := uint8(1); b != 0; b++ {
|
||||||
if mult(a, b) == 1 {
|
if gfMult(a, b) == 1 {
|
||||||
c = b
|
inv = b
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return c
|
return inv
|
||||||
}
|
}
|
||||||
|
@ -100,14 +100,14 @@ func TestFieldOperations(t *testing.T) {
|
|||||||
a, b, expected uint8
|
a, b, expected uint8
|
||||||
op func(uint8, uint8) (uint8, error)
|
op func(uint8, uint8) (uint8, error)
|
||||||
}{
|
}{
|
||||||
{16, 16, 0, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
|
{16, 16, 0, func(a, b uint8) (uint8, error) { return gfAdd(a, b), nil }},
|
||||||
{3, 4, 7, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
|
{3, 4, 7, func(a, b uint8) (uint8, error) { return gfAdd(a, b), nil }},
|
||||||
{3, 7, 9, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
{3, 7, 9, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
|
||||||
{3, 0, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
{3, 0, 0, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
|
||||||
{0, 3, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
{0, 3, 0, func(a, b uint8) (uint8, error) { return gfMult(a, b), nil }},
|
||||||
{0, 7, 0, div},
|
{0, 7, 0, gfDiv},
|
||||||
{3, 3, 1, div},
|
{3, 3, 1, gfDiv},
|
||||||
{6, 3, 2, div},
|
{6, 3, 2, gfDiv},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -127,8 +127,8 @@ func TestPolynomialCreationAndEvaluation(t *testing.T) {
|
|||||||
t.Fatalf("NewPolynomial error: %v", err)
|
t.Fatalf("NewPolynomial error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.coefficients[0] != 42 {
|
if p.coeffs[0] != 42 {
|
||||||
t.Fatalf("expected intercept 42, got %d", p.coefficients[0])
|
t.Fatalf("expected intercept 42, got %d", p.coeffs[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
if out := p.evaluate(0); out != 42 {
|
if out := p.evaluate(0); out != 42 {
|
||||||
@ -136,7 +136,7 @@ func TestPolynomialCreationAndEvaluation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
x := uint8(1)
|
x := uint8(1)
|
||||||
expected := add(42, mult(x, p.coefficients[1]))
|
expected := gfAdd(42, gfMult(x, p.coeffs[1]))
|
||||||
if out := p.evaluate(x); out != expected {
|
if out := p.evaluate(x); out != expected {
|
||||||
t.Fatalf("expected %d, got %d", expected, out)
|
t.Fatalf("expected %d, got %d", expected, out)
|
||||||
}
|
}
|
||||||
@ -151,9 +151,9 @@ func TestPolynomialInterpolation(t *testing.T) {
|
|||||||
|
|
||||||
xVals := []uint8{1, 2, 3}
|
xVals := []uint8{1, 2, 3}
|
||||||
yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
|
yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
|
||||||
out, err := interpolatePolynomialSafe(xVals, yVals, 0)
|
out, err := interpolatePolynomial(xVals, yVals, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("InterpolatePolynomialSafe error: %v", err)
|
t.Fatalf("interpolatePolynomial error: %v", err)
|
||||||
}
|
}
|
||||||
if out != uint8(i) {
|
if out != uint8(i) {
|
||||||
t.Fatalf("expected %d, got %d", uint8(i), out)
|
t.Fatalf("expected %d, got %d", uint8(i), out)
|
||||||
|
Loading…
Reference in New Issue
Block a user