initial commit

This commit is contained in:
saepire 2024-07-29 21:45:43 +09:00
commit 80c4d9b00c
7 changed files with 474 additions and 0 deletions

13
LICENSE Normal file
View File

@ -0,0 +1,13 @@
Copyright (c) 2024 Kyodo Tech合同会社 <opensource@kyodo.tech>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

49
README.md Normal file
View File

@ -0,0 +1,49 @@
## Shamir's Secret Sharing
This repository provides a simple implementation of Shamir's Secret Sharing in Go, allowing to split a secret into multiple shares and reconstruct it using a subset of those shares.
- Split a secret into `N` shares with a threshold of `T` shares required to reconstruct the secret.
- Arithmetic operations in Galois Field (GF(`2^8`)).
- Polynomial creation and evaluation.
### Usage
```go
package main
import (
"fmt"
"github.com/kyodo-tech/shamir"
)
func main() {
secret := []byte("my secret")
shares, err := shamir.Split(secret, 5, 3)
if err != nil {
panic(err)
}
// Use any 3 out of 5 shares to reconstruct the secret
reconstructed, err := shamir.Combine(shares[:3])
if err != nil {
panic(err)
}
fmt.Printf("Reconstructed secret: %s\n", reconstructed)
}
```
Also see the two example programs in the `example` directory:
```sh
go run example/split.go "my secret"
# Output:
# Share 0: lK8LyefHjxzOAQ==
# Share 1: 1Ws3IGJh/lVQAg==
# Share 2: LL0cmuDFAyzqAw==
# Share 3: Wf8RJQs43iALBA==
# Share 4: oCk6n4mcI1mxBQ==
go run ./example/combine.go lK8LyefHjxzOAQ==,LL0cmuDFAyzqAw==,oCk6n4mcI1mxBQ==
# Output:
# Reconstructed secret: my secret
```

39
example/combine.go Normal file
View File

@ -0,0 +1,39 @@
package main
import (
"encoding/base64"
"fmt"
"os"
"strings"
"github.com/kyodo-tech/shamir"
)
func main() {
// read comma separated shares from command line
if len(os.Args) != 2 {
fmt.Println("Usage: go run main.go <share1>,<share2>,...,<shareN>")
os.Exit(1)
}
// split shares
sharesStr := strings.Split(os.Args[1], ",")
var shares [][]byte
for _, shareStr := range sharesStr {
share, err := base64.StdEncoding.DecodeString(shareStr)
if err != nil {
panic(err)
}
shares = append(shares, share)
}
// Use any 3 out of 5 shares to reconstruct the secret
reconstructed, err := shamir.Combine(shares[:3])
if err != nil {
panic(err)
}
fmt.Printf("Reconstructed secret: %s\n", reconstructed)
}

27
example/split.go Normal file
View File

@ -0,0 +1,27 @@
package main
import (
"encoding/base64"
"fmt"
"os"
"github.com/kyodo-tech/shamir"
)
func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: go run main.go <secret>")
os.Exit(1)
}
secret := []byte(os.Args[1])
shares, err := shamir.Split(secret, 5, 3)
if err != nil {
panic(err)
}
// print share strings
for i, share := range shares {
fmt.Printf("Share %d: %s\n", i, base64.StdEncoding.EncodeToString(share))
}
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module github.com/kyodo-tech/shamir
go 1.22.2

181
shamir.go Normal file
View File

@ -0,0 +1,181 @@
package shamir
import (
"crypto/rand"
"fmt"
)
type polynomial struct {
coefficients []uint8
}
func newPolynomial(intercept, degree uint8) (*polynomial, error) {
p := &polynomial{
coefficients: make([]uint8, degree+1),
}
p.coefficients[0] = intercept
if _, err := rand.Read(p.coefficients[1:]); err != nil {
return nil, err
}
return p, nil
}
func (p *polynomial) evaluate(x uint8) uint8 {
if x == 0 {
return p.coefficients[0]
}
out := p.coefficients[len(p.coefficients)-1]
for i := len(p.coefficients) - 2; i >= 0; i-- {
out = add(mult(out, x), p.coefficients[i])
}
return out
}
// Split divides the secret into parts shares with a threshold of minimum shares to reconstruct the secret.
func Split(secret []byte, N, T int) ([][]byte, error) {
if N < T {
return nil, fmt.Errorf("parts cannot be less than threshold")
}
if N > 255 {
return nil, fmt.Errorf("parts cannot exceed 255")
}
if T < 2 {
return nil, fmt.Errorf("threshold must be at least 2")
}
if T > 255 {
return nil, fmt.Errorf("threshold cannot exceed 255")
}
if len(secret) == 0 {
return nil, fmt.Errorf("cannot split an empty secret")
}
// Generate unique x-coordinates for each share
xCoordinates := make([]uint8, N)
for i := 0; i < N; i++ {
xCoordinates[i] = uint8(i + 1)
}
// Initialize shares with the secret length + 1 (for the x-coordinate)
shares := make([][]byte, N)
for i := range shares {
shares[i] = make([]byte, len(secret)+1)
shares[i][len(secret)] = xCoordinates[i]
}
// Create a polynomial for each byte in the secret and evaluate it at each x-coordinate
for i, b := range secret {
p, err := newPolynomial(b, uint8(T-1))
if err != nil {
return nil, err
}
for j := 0; j < N; j++ {
shares[j][i] = p.evaluate(xCoordinates[j])
}
}
return shares, nil
}
// Combine reconstructs the secret from the provided shares.
func Combine(shares [][]byte) ([]byte, error) {
if len(shares) < 2 {
return nil, fmt.Errorf("less than two shares cannot be used to reconstruct the secret")
}
shareLength := len(shares[0])
if shareLength < 2 {
return nil, fmt.Errorf("shares must be at least two bytes long")
}
for _, share := range shares {
if len(share) != shareLength {
return nil, fmt.Errorf("all shares must be the same length")
}
}
secret := make([]byte, shareLength-1)
xSamples := make([]uint8, len(shares))
ySamples := make([]uint8, len(shares))
for i, share := range shares {
xSamples[i] = share[shareLength-1]
}
for i := range secret {
for j, share := range shares {
ySamples[j] = share[i]
}
val, err := interpolatePolynomialSafe(xSamples, ySamples, 0)
if err != nil {
return nil, err
}
secret[i] = val
}
return secret, nil
}
func interpolatePolynomialSafe(xSamples, ySamples []uint8, x uint8) (uint8, error) {
result := uint8(0)
for i := range xSamples {
num, denom := uint8(1), uint8(1)
for j := range xSamples {
if i != j {
num = mult(num, add(x, xSamples[j]))
denom = mult(denom, add(xSamples[i], xSamples[j]))
}
}
term, err := div(num, denom)
if err != nil {
return 0, err
}
result = add(result, mult(ySamples[i], term))
}
return result, nil
}
// Helper functions for arithmetic in GF(2^8)
func add(a, b uint8) uint8 {
return a ^ b
}
func mult(a, b uint8) uint8 {
var p uint8
for b > 0 {
if b&1 == 1 {
p ^= a
}
if a&0x80 > 0 {
a = (a << 1) ^ 0x1B
} else {
a <<= 1
}
b >>= 1
}
return p
}
func div(a, b uint8) (uint8, error) {
if b == 0 {
return 0, fmt.Errorf("division by zero")
}
return mult(a, inverse(b)), nil
}
func inverse(a uint8) uint8 {
var b, c uint8
for b = 1; b != 0; b++ {
if mult(a, b) == 1 {
c = b
break
}
}
return c
}

162
shamir_test.go Normal file
View File

@ -0,0 +1,162 @@
package shamir
import (
"bytes"
"testing"
)
func TestSplitInvalid(t *testing.T) {
secret := []byte("test")
tests := []struct {
parts int
threshold int
}{
{0, 0},
{2, 3},
{1000, 3},
{10, 1},
{3, 256},
}
for _, tt := range tests {
if _, err := Split(secret, tt.parts, tt.threshold); err == nil {
t.Fatalf("expected error for parts: %d, threshold: %d", tt.parts, tt.threshold)
}
}
if _, err := Split(nil, 3, 2); err == nil {
t.Fatalf("expected error for nil secret")
}
}
func TestSplit(t *testing.T) {
secret := []byte("test")
out, err := Split(secret, 5, 3)
if err != nil {
t.Fatalf("Split error: %v", err)
}
if len(out) != 5 {
t.Fatalf("expected 5 shares, got %d", len(out))
}
for _, share := range out {
if len(share) != len(secret)+1 {
t.Fatalf("expected share length %d, got %d", len(secret)+1, len(share))
}
}
}
func TestCombineInvalid(t *testing.T) {
tests := [][][]byte{
nil,
{[]byte("foo"), []byte("ba")},
{[]byte("f"), []byte("b")},
{[]byte("foo"), []byte("foo")},
}
for _, parts := range tests {
if _, err := Combine(parts); err == nil {
t.Fatalf("expected error for parts: %v", parts)
}
}
}
func TestCombine(t *testing.T) {
secret := []byte("test")
out, err := Split(secret, 5, 3)
if err != nil {
t.Fatalf("Split error: %v", err)
}
for i := 0; i < 5; i++ {
for j := 0; j < 5; j++ {
if j == i {
continue
}
for k := 0; k < 5; k++ {
if k == i || k == j {
continue
}
parts := [][]byte{out[i], out[j], out[k]}
recomb, err := Combine(parts)
if err != nil {
t.Fatalf("Combine error: %v", err)
}
if !bytes.Equal(recomb, secret) {
t.Fatalf("expected %v, got %v", secret, recomb)
}
}
}
}
}
func TestFieldOperations(t *testing.T) {
tests := []struct {
a, b, expected uint8
op func(uint8, uint8) (uint8, error)
}{
{16, 16, 0, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
{3, 4, 7, func(a, b uint8) (uint8, error) { return add(a, b), nil }},
{3, 7, 9, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
{3, 0, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
{0, 3, 0, func(a, b uint8) (uint8, error) { return mult(a, b), nil }},
{0, 7, 0, div},
{3, 3, 1, div},
{6, 3, 2, div},
}
for _, tt := range tests {
out, err := tt.op(tt.a, tt.b)
if err != nil {
t.Fatalf("operation error: %v", err)
}
if out != tt.expected {
t.Fatalf("expected %d, got %d", tt.expected, out)
}
}
}
func TestPolynomialCreationAndEvaluation(t *testing.T) {
p, err := newPolynomial(42, 1)
if err != nil {
t.Fatalf("NewPolynomial error: %v", err)
}
if p.coefficients[0] != 42 {
t.Fatalf("expected intercept 42, got %d", p.coefficients[0])
}
if out := p.evaluate(0); out != 42 {
t.Fatalf("expected 42, got %d", out)
}
x := uint8(1)
expected := add(42, mult(x, p.coefficients[1]))
if out := p.evaluate(x); out != expected {
t.Fatalf("expected %d, got %d", expected, out)
}
}
func TestPolynomialInterpolation(t *testing.T) {
for i := 0; i < 256; i++ {
p, err := newPolynomial(uint8(i), 2)
if err != nil {
t.Fatalf("NewPolynomial error: %v", err)
}
xVals := []uint8{1, 2, 3}
yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
out, err := interpolatePolynomialSafe(xVals, yVals, 0)
if err != nil {
t.Fatalf("InterpolatePolynomialSafe error: %v", err)
}
if out != uint8(i) {
t.Fatalf("expected %d, got %d", uint8(i), out)
}
}
}