initial commit
This commit is contained in:
commit
80c4d9b00c
13
LICENSE
Normal file
13
LICENSE
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
Copyright (c) 2024 Kyodo Tech合同会社 <opensource@kyodo.tech>
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
49
README.md
Normal file
49
README.md
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
## Shamir's Secret Sharing
|
||||||
|
|
||||||
|
This repository provides a simple implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares.
|
||||||
|
|
||||||
|
- Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret.
|
||||||
|
- Arithmetic operations in Galois Field (GF(`2^8`)).
|
||||||
|
- Polynomial creation and evaluation.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/kyodo-tech/shamir"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
secret := []byte("my secret")
|
||||||
|
shares, err := shamir.Split(secret, 5, 3)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use any 3 out of 5 shares to reconstruct the secret
|
||||||
|
reconstructed, err := shamir.Combine(shares[:3])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Reconstructed secret: %s\n", reconstructed)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Also see the two example programs in the `example` directory:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go run example/split.go "my secret"
|
||||||
|
# Output:
|
||||||
|
# Share 0: lK8LyefHjxzOAQ==
|
||||||
|
# Share 1: 1Ws3IGJh/lVQAg==
|
||||||
|
# Share 2: LL0cmuDFAyzqAw==
|
||||||
|
# Share 3: Wf8RJQs43iALBA==
|
||||||
|
# Share 4: oCk6n4mcI1mxBQ==
|
||||||
|
go run ./example/combine.go lK8LyefHjxzOAQ==,LL0cmuDFAyzqAw==,oCk6n4mcI1mxBQ==
|
||||||
|
# Output:
|
||||||
|
# Reconstructed secret: my secret
|
||||||
|
```
|
39
example/combine.go
Normal file
39
example/combine.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/kyodo-tech/shamir"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// read comma separated shares from command line
|
||||||
|
if len(os.Args) != 2 {
|
||||||
|
fmt.Println("Usage: go run main.go <share1>,<share2>,...,<shareN>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// split shares
|
||||||
|
sharesStr := strings.Split(os.Args[1], ",")
|
||||||
|
|
||||||
|
var shares [][]byte
|
||||||
|
for _, shareStr := range sharesStr {
|
||||||
|
share, err := base64.StdEncoding.DecodeString(shareStr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shares = append(shares, share)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use any 3 out of 5 shares to reconstruct the secret
|
||||||
|
reconstructed, err := shamir.Combine(shares[:3])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Reconstructed secret: %s\n", reconstructed)
|
||||||
|
}
|
27
example/split.go
Normal file
27
example/split.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/kyodo-tech/shamir"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) != 2 {
|
||||||
|
fmt.Println("Usage: go run main.go <secret>")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
secret := []byte(os.Args[1])
|
||||||
|
shares, err := shamir.Split(secret, 5, 3)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// print share strings
|
||||||
|
for i, share := range shares {
|
||||||
|
fmt.Printf("Share %d: %s\n", i, base64.StdEncoding.EncodeToString(share))
|
||||||
|
}
|
||||||
|
}
|
181
shamir.go
Normal file
181
shamir.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package shamir
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type polynomial struct {
|
||||||
|
coefficients []uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
|
||||||
|
p := &polynomial{
|
||||||
|
coefficients: make([]uint8, degree+1),
|
||||||
|
}
|
||||||
|
p.coefficients[0] = intercept
|
||||||
|
|
||||||
|
if _, err := rand.Read(p.coefficients[1:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *polynomial) evaluate(x uint8) uint8 {
|
||||||
|
if x == 0 {
|
||||||
|
return p.coefficients[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
out := p.coefficients[len(p.coefficients)-1]
|
||||||
|
for i := len(p.coefficients) - 2; i >= 0; i-- {
|
||||||
|
out = add(mult(out, x), p.coefficients[i])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split divides the secret into parts shares with a threshold of minimum shares to reconstruct the secret.
|
||||||
|
func Split(secret []byte, N, T int) ([][]byte, error) {
|
||||||
|
if N < T {
|
||||||
|
return nil, fmt.Errorf("parts cannot be less than threshold")
|
||||||
|
}
|
||||||
|
if N > 255 {
|
||||||
|
return nil, fmt.Errorf("parts cannot exceed 255")
|
||||||
|
}
|
||||||
|
if T < 2 {
|
||||||
|
return nil, fmt.Errorf("threshold must be at least 2")
|
||||||
|
}
|
||||||
|
if T > 255 {
|
||||||
|
return nil, fmt.Errorf("threshold cannot exceed 255")
|
||||||
|
}
|
||||||
|
if len(secret) == 0 {
|
||||||
|
return nil, fmt.Errorf("cannot split an empty secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique x-coordinates for each share
|
||||||
|
xCoordinates := make([]uint8, N)
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
xCoordinates[i] = uint8(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize shares with the secret length + 1 (for the x-coordinate)
|
||||||
|
shares := make([][]byte, N)
|
||||||
|
for i := range shares {
|
||||||
|
shares[i] = make([]byte, len(secret)+1)
|
||||||
|
shares[i][len(secret)] = xCoordinates[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a polynomial for each byte in the secret and evaluate it at each x-coordinate
|
||||||
|
for i, b := range secret {
|
||||||
|
p, err := newPolynomial(b, uint8(T-1))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < N; j++ {
|
||||||
|
shares[j][i] = p.evaluate(xCoordinates[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return shares, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine reconstructs the secret from the provided shares.
|
||||||
|
func Combine(shares [][]byte) ([]byte, error) {
|
||||||
|
if len(shares) < 2 {
|
||||||
|
return nil, fmt.Errorf("less than two shares cannot be used to reconstruct the secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
shareLength := len(shares[0])
|
||||||
|
if shareLength < 2 {
|
||||||
|
return nil, fmt.Errorf("shares must be at least two bytes long")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, share := range shares {
|
||||||
|
if len(share) != shareLength {
|
||||||
|
return nil, fmt.Errorf("all shares must be the same length")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
secret := make([]byte, shareLength-1)
|
||||||
|
xSamples := make([]uint8, len(shares))
|
||||||
|
ySamples := make([]uint8, len(shares))
|
||||||
|
|
||||||
|
for i, share := range shares {
|
||||||
|
xSamples[i] = share[shareLength-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range secret {
|
||||||
|
for j, share := range shares {
|
||||||
|
ySamples[j] = share[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := interpolatePolynomialSafe(xSamples, ySamples, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
secret[i] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return secret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func interpolatePolynomialSafe(xSamples, ySamples []uint8, x uint8) (uint8, error) {
|
||||||
|
result := uint8(0)
|
||||||
|
for i := range xSamples {
|
||||||
|
num, denom := uint8(1), uint8(1)
|
||||||
|
for j := range xSamples {
|
||||||
|
if i != j {
|
||||||
|
num = mult(num, add(x, xSamples[j]))
|
||||||
|
denom = mult(denom, add(xSamples[i], xSamples[j]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
term, err := div(num, denom)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
result = add(result, mult(ySamples[i], term))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for arithmetic in GF(2^8)
|
||||||
|
|
||||||
|
func add(a, b uint8) uint8 {
|
||||||
|
return a ^ b
|
||||||
|
}
|
||||||
|
|
||||||
|
func mult(a, b uint8) uint8 {
|
||||||
|
var p uint8
|
||||||
|
for b > 0 {
|
||||||
|
if b&1 == 1 {
|
||||||
|
p ^= a
|
||||||
|
}
|
||||||
|
if a&0x80 > 0 {
|
||||||
|
a = (a << 1) ^ 0x1B
|
||||||
|
} else {
|
||||||
|
a <<= 1
|
||||||
|
}
|
||||||
|
b >>= 1
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func div(a, b uint8) (uint8, error) {
|
||||||
|
if b == 0 {
|
||||||
|
return 0, fmt.Errorf("division by zero")
|
||||||
|
}
|
||||||
|
return mult(a, inverse(b)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inverse(a uint8) uint8 {
|
||||||
|
var b, c uint8
|
||||||
|
for b = 1; b != 0; b++ {
|
||||||
|
if mult(a, b) == 1 {
|
||||||
|
c = b
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
162
shamir_test.go
Normal file
162
shamir_test.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package shamir
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSplitInvalid(t *testing.T) {
|
||||||
|
secret := []byte("test")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
parts int
|
||||||
|
threshold int
|
||||||
|
}{
|
||||||
|
{0, 0},
|
||||||
|
{2, 3},
|
||||||
|
{1000, 3},
|
||||||
|
{10, 1},
|
||||||
|
{3, 256},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if _, err := Split(secret, tt.parts, tt.threshold); err == nil {
|
||||||
|
t.Fatalf("expected error for parts: %d, threshold: %d", tt.parts, tt.threshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := Split(nil, 3, 2); err == nil {
|
||||||
|
t.Fatalf("expected error for nil secret")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplit(t *testing.T) {
|
||||||
|
secret := []byte("test")
|
||||||
|
|
||||||
|
out, err := Split(secret, 5, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Split error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(out) != 5 {
|
||||||
|
t.Fatalf("expected 5 shares, got %d", len(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, share := range out {
|
||||||
|
if len(share) != len(secret)+1 {
|
||||||
|
t.Fatalf("expected share length %d, got %d", len(secret)+1, len(share))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombineInvalid(t *testing.T) {
|
||||||
|
tests := [][][]byte{
|
||||||
|
nil,
|
||||||
|
{[]byte("foo"), []byte("ba")},
|
||||||
|
{[]byte("f"), []byte("b")},
|
||||||
|
{[]byte("foo"), []byte("foo")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, parts := range tests {
|
||||||
|
if _, err := Combine(parts); err == nil {
|
||||||
|
t.Fatalf("expected error for parts: %v", parts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCombine(t *testing.T) {
|
||||||
|
secret := []byte("test")
|
||||||
|
|
||||||
|
out, err := Split(secret, 5, 3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Split error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
for j := 0; j < 5; j++ {
|
||||||
|
if j == i {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for k := 0; k < 5; k++ {
|
||||||
|
if k == i || k == j {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := [][]byte{out[i], out[j], out[k]}
|
||||||
|
recomb, err := Combine(parts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Combine error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(recomb, secret) {
|
||||||
|
t.Fatalf("expected %v, got %v", secret, recomb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldOperations(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b, expected uint8
|
||||||
|
op func(uint8, uint8) (uint8, error)
|
||||||
|
}{
|
||||||
|
{16, 16, 0, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
|
||||||
|
{3, 4, 7, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
|
||||||
|
{3, 7, 9, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
||||||
|
{3, 0, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
||||||
|
{0, 3, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
|
||||||
|
{0, 7, 0, div},
|
||||||
|
{3, 3, 1, div},
|
||||||
|
{6, 3, 2, div},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
out, err := tt.op(tt.a, tt.b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("operation error: %v", err)
|
||||||
|
}
|
||||||
|
if out != tt.expected {
|
||||||
|
t.Fatalf("expected %d, got %d", tt.expected, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolynomialCreationAndEvaluation(t *testing.T) {
|
||||||
|
p, err := newPolynomial(42, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewPolynomial error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.coefficients[0] != 42 {
|
||||||
|
t.Fatalf("expected intercept 42, got %d", p.coefficients[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if out := p.evaluate(0); out != 42 {
|
||||||
|
t.Fatalf("expected 42, got %d", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
x := uint8(1)
|
||||||
|
expected := add(42, mult(x, p.coefficients[1]))
|
||||||
|
if out := p.evaluate(x); out != expected {
|
||||||
|
t.Fatalf("expected %d, got %d", expected, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolynomialInterpolation(t *testing.T) {
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
p, err := newPolynomial(uint8(i), 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewPolynomial error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xVals := []uint8{1, 2, 3}
|
||||||
|
yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
|
||||||
|
out, err := interpolatePolynomialSafe(xVals, yVals, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("InterpolatePolynomialSafe error: %v", err)
|
||||||
|
}
|
||||||
|
if out != uint8(i) {
|
||||||
|
t.Fatalf("expected %d, got %d", uint8(i), out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user