...

Source file src/database/sql/convert_test.go

Documentation: database/sql

		 1  // Copyright 2011 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package sql
		 6  
		 7  import (
		 8  	"database/sql/driver"
		 9  	"fmt"
		10  	"reflect"
		11  	"runtime"
		12  	"strings"
		13  	"sync"
		14  	"testing"
		15  	"time"
		16  )
		17  
		18  var someTime = time.Unix(123, 0)
		19  var answer int64 = 42
		20  
		21  type (
		22  	userDefined			 float64
		23  	userDefinedSlice	[]int
		24  	userDefinedString string
		25  )
		26  
		27  type conversionTest struct {
		28  	s, d interface{} // source and destination
		29  
		30  	// following are used if they're non-zero
		31  	wantint		int64
		32  	wantuint	 uint64
		33  	wantstr		string
		34  	wantbytes	[]byte
		35  	wantraw		RawBytes
		36  	wantf32		float32
		37  	wantf64		float64
		38  	wanttime	 time.Time
		39  	wantbool	 bool // used if d is of type *bool
		40  	wanterr		string
		41  	wantiface	interface{}
		42  	wantptr		*int64 // if non-nil, *d's pointed value must be equal to *wantptr
		43  	wantnil		bool	 // if true, *d must be *int64(nil)
		44  	wantusrdef userDefined
		45  	wantusrstr userDefinedString
		46  }
		47  
		48  // Target variables for scanning into.
		49  var (
		50  	scanstr		string
		51  	scanbytes	[]byte
		52  	scanraw		RawBytes
		53  	scanint		int
		54  	scanint8	 int8
		55  	scanint16	int16
		56  	scanint32	int32
		57  	scanuint8	uint8
		58  	scanuint16 uint16
		59  	scanbool	 bool
		60  	scanf32		float32
		61  	scanf64		float64
		62  	scantime	 time.Time
		63  	scanptr		*int64
		64  	scaniface	interface{}
		65  )
		66  
		67  func conversionTests() []conversionTest {
		68  	// Return a fresh instance to test so "go test -count 2" works correctly.
		69  	return []conversionTest{
		70  		// Exact conversions (destination pointer type matches source type)
		71  		{s: "foo", d: &scanstr, wantstr: "foo"},
		72  		{s: 123, d: &scanint, wantint: 123},
		73  		{s: someTime, d: &scantime, wanttime: someTime},
		74  
		75  		// To strings
		76  		{s: "string", d: &scanstr, wantstr: "string"},
		77  		{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
		78  		{s: 123, d: &scanstr, wantstr: "123"},
		79  		{s: int8(123), d: &scanstr, wantstr: "123"},
		80  		{s: int64(123), d: &scanstr, wantstr: "123"},
		81  		{s: uint8(123), d: &scanstr, wantstr: "123"},
		82  		{s: uint16(123), d: &scanstr, wantstr: "123"},
		83  		{s: uint32(123), d: &scanstr, wantstr: "123"},
		84  		{s: uint64(123), d: &scanstr, wantstr: "123"},
		85  		{s: 1.5, d: &scanstr, wantstr: "1.5"},
		86  
		87  		// From time.Time:
		88  		{s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
		89  		{s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
		90  		{s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
		91  		{s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
		92  		{s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
		93  		{s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},
		94  
		95  		// To []byte
		96  		{s: nil, d: &scanbytes, wantbytes: nil},
		97  		{s: "string", d: &scanbytes, wantbytes: []byte("string")},
		98  		{s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
		99  		{s: 123, d: &scanbytes, wantbytes: []byte("123")},
	 100  		{s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
	 101  		{s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
	 102  		{s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
	 103  		{s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
	 104  		{s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
	 105  		{s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
	 106  		{s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
	 107  
	 108  		// To RawBytes
	 109  		{s: nil, d: &scanraw, wantraw: nil},
	 110  		{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
	 111  		{s: "string", d: &scanraw, wantraw: RawBytes("string")},
	 112  		{s: 123, d: &scanraw, wantraw: RawBytes("123")},
	 113  		{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
	 114  		{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
	 115  		{s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
	 116  		{s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
	 117  		{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
	 118  		{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
	 119  		{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
	 120  		// time.Time has been placed here to check that the RawBytes slice gets
	 121  		// correctly reset when calling time.Time.AppendFormat.
	 122  		{s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},
	 123  
	 124  		// Strings to integers
	 125  		{s: "255", d: &scanuint8, wantuint: 255},
	 126  		{s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
	 127  		{s: "256", d: &scanuint16, wantuint: 256},
	 128  		{s: "-1", d: &scanint, wantint: -1},
	 129  		{s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},
	 130  
	 131  		// int64 to smaller integers
	 132  		{s: int64(5), d: &scanuint8, wantuint: 5},
	 133  		{s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
	 134  		{s: int64(256), d: &scanuint16, wantuint: 256},
	 135  		{s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},
	 136  
	 137  		// True bools
	 138  		{s: true, d: &scanbool, wantbool: true},
	 139  		{s: "True", d: &scanbool, wantbool: true},
	 140  		{s: "TRUE", d: &scanbool, wantbool: true},
	 141  		{s: "1", d: &scanbool, wantbool: true},
	 142  		{s: 1, d: &scanbool, wantbool: true},
	 143  		{s: int64(1), d: &scanbool, wantbool: true},
	 144  		{s: uint16(1), d: &scanbool, wantbool: true},
	 145  
	 146  		// False bools
	 147  		{s: false, d: &scanbool, wantbool: false},
	 148  		{s: "false", d: &scanbool, wantbool: false},
	 149  		{s: "FALSE", d: &scanbool, wantbool: false},
	 150  		{s: "0", d: &scanbool, wantbool: false},
	 151  		{s: 0, d: &scanbool, wantbool: false},
	 152  		{s: int64(0), d: &scanbool, wantbool: false},
	 153  		{s: uint16(0), d: &scanbool, wantbool: false},
	 154  
	 155  		// Not bools
	 156  		{s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
	 157  		{s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
	 158  
	 159  		// Floats
	 160  		{s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
	 161  		{s: int64(1), d: &scanf64, wantf64: float64(1)},
	 162  		{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
	 163  		{s: "1.5", d: &scanf32, wantf32: float32(1.5)},
	 164  		{s: "1.5", d: &scanf64, wantf64: float64(1.5)},
	 165  
	 166  		// Pointers
	 167  		{s: interface{}(nil), d: &scanptr, wantnil: true},
	 168  		{s: int64(42), d: &scanptr, wantptr: &answer},
	 169  
	 170  		// To interface{}
	 171  		{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
	 172  		{s: int64(1), d: &scaniface, wantiface: int64(1)},
	 173  		{s: "str", d: &scaniface, wantiface: "str"},
	 174  		{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
	 175  		{s: true, d: &scaniface, wantiface: true},
	 176  		{s: nil, d: &scaniface},
	 177  		{s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
	 178  
	 179  		// To a user-defined type
	 180  		{s: 1.5, d: new(userDefined), wantusrdef: 1.5},
	 181  		{s: int64(123), d: new(userDefined), wantusrdef: 123},
	 182  		{s: "1.5", d: new(userDefined), wantusrdef: 1.5},
	 183  		{s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
	 184  		{s: "str", d: new(userDefinedString), wantusrstr: "str"},
	 185  
	 186  		// Other errors
	 187  		{s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
	 188  	}
	 189  }
	 190  
	 191  func intPtrValue(intptr interface{}) interface{} {
	 192  	return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
	 193  }
	 194  
	 195  func intValue(intptr interface{}) int64 {
	 196  	return reflect.Indirect(reflect.ValueOf(intptr)).Int()
	 197  }
	 198  
	 199  func uintValue(intptr interface{}) uint64 {
	 200  	return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
	 201  }
	 202  
	 203  func float64Value(ptr interface{}) float64 {
	 204  	return *(ptr.(*float64))
	 205  }
	 206  
	 207  func float32Value(ptr interface{}) float32 {
	 208  	return *(ptr.(*float32))
	 209  }
	 210  
	 211  func timeValue(ptr interface{}) time.Time {
	 212  	return *(ptr.(*time.Time))
	 213  }
	 214  
	 215  func TestConversions(t *testing.T) {
	 216  	for n, ct := range conversionTests() {
	 217  		err := convertAssign(ct.d, ct.s)
	 218  		errstr := ""
	 219  		if err != nil {
	 220  			errstr = err.Error()
	 221  		}
	 222  		errf := func(format string, args ...interface{}) {
	 223  			base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
	 224  			t.Errorf(base+format, args...)
	 225  		}
	 226  		if errstr != ct.wanterr {
	 227  			errf("got error %q, want error %q", errstr, ct.wanterr)
	 228  		}
	 229  		if ct.wantstr != "" && ct.wantstr != scanstr {
	 230  			errf("want string %q, got %q", ct.wantstr, scanstr)
	 231  		}
	 232  		if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
	 233  			errf("want byte %q, got %q", ct.wantbytes, scanbytes)
	 234  		}
	 235  		if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
	 236  			errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
	 237  		}
	 238  		if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
	 239  			errf("want int %d, got %d", ct.wantint, intValue(ct.d))
	 240  		}
	 241  		if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
	 242  			errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
	 243  		}
	 244  		if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
	 245  			errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
	 246  		}
	 247  		if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
	 248  			errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
	 249  		}
	 250  		if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
	 251  			errf("want bool %v, got %v", ct.wantbool, *bp)
	 252  		}
	 253  		if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
	 254  			errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
	 255  		}
	 256  		if ct.wantnil && *ct.d.(**int64) != nil {
	 257  			errf("want nil, got %v", intPtrValue(ct.d))
	 258  		}
	 259  		if ct.wantptr != nil {
	 260  			if *ct.d.(**int64) == nil {
	 261  				errf("want pointer to %v, got nil", *ct.wantptr)
	 262  			} else if *ct.wantptr != intPtrValue(ct.d) {
	 263  				errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
	 264  			}
	 265  		}
	 266  		if ifptr, ok := ct.d.(*interface{}); ok {
	 267  			if !reflect.DeepEqual(ct.wantiface, scaniface) {
	 268  				errf("want interface %#v, got %#v", ct.wantiface, scaniface)
	 269  				continue
	 270  			}
	 271  			if srcBytes, ok := ct.s.([]byte); ok {
	 272  				dstBytes := (*ifptr).([]byte)
	 273  				if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
	 274  					errf("copy into interface{} didn't copy []byte data")
	 275  				}
	 276  			}
	 277  		}
	 278  		if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
	 279  			errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
	 280  		}
	 281  		if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
	 282  			errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
	 283  		}
	 284  	}
	 285  }
	 286  
	 287  func TestNullString(t *testing.T) {
	 288  	var ns NullString
	 289  	convertAssign(&ns, []byte("foo"))
	 290  	if !ns.Valid {
	 291  		t.Errorf("expecting not null")
	 292  	}
	 293  	if ns.String != "foo" {
	 294  		t.Errorf("expecting foo; got %q", ns.String)
	 295  	}
	 296  	convertAssign(&ns, nil)
	 297  	if ns.Valid {
	 298  		t.Errorf("expecting null on nil")
	 299  	}
	 300  	if ns.String != "" {
	 301  		t.Errorf("expecting blank on nil; got %q", ns.String)
	 302  	}
	 303  }
	 304  
	 305  type valueConverterTest struct {
	 306  	c			 driver.ValueConverter
	 307  	in, out interface{}
	 308  	err		 string
	 309  }
	 310  
	 311  var valueConverterTests = []valueConverterTest{
	 312  	{driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
	 313  	{driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
	 314  }
	 315  
	 316  func TestValueConverters(t *testing.T) {
	 317  	for i, tt := range valueConverterTests {
	 318  		out, err := tt.c.ConvertValue(tt.in)
	 319  		goterr := ""
	 320  		if err != nil {
	 321  			goterr = err.Error()
	 322  		}
	 323  		if goterr != tt.err {
	 324  			t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
	 325  				i, tt.c, tt.in, tt.in, goterr, tt.err)
	 326  		}
	 327  		if tt.err != "" {
	 328  			continue
	 329  		}
	 330  		if !reflect.DeepEqual(out, tt.out) {
	 331  			t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
	 332  				i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
	 333  		}
	 334  	}
	 335  }
	 336  
	 337  // Tests that assigning to RawBytes doesn't allocate (and also works).
	 338  func TestRawBytesAllocs(t *testing.T) {
	 339  	var tests = []struct {
	 340  		name string
	 341  		in	 interface{}
	 342  		want string
	 343  	}{
	 344  		{"uint64", uint64(12345678), "12345678"},
	 345  		{"uint32", uint32(1234), "1234"},
	 346  		{"uint16", uint16(12), "12"},
	 347  		{"uint8", uint8(1), "1"},
	 348  		{"uint", uint(123), "123"},
	 349  		{"int", int(123), "123"},
	 350  		{"int8", int8(1), "1"},
	 351  		{"int16", int16(12), "12"},
	 352  		{"int32", int32(1234), "1234"},
	 353  		{"int64", int64(12345678), "12345678"},
	 354  		{"float32", float32(1.5), "1.5"},
	 355  		{"float64", float64(64), "64"},
	 356  		{"bool", false, "false"},
	 357  		{"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
	 358  	}
	 359  
	 360  	buf := make(RawBytes, 10)
	 361  	test := func(name string, in interface{}, want string) {
	 362  		if err := convertAssign(&buf, in); err != nil {
	 363  			t.Fatalf("%s: convertAssign = %v", name, err)
	 364  		}
	 365  		match := len(buf) == len(want)
	 366  		if match {
	 367  			for i, b := range buf {
	 368  				if want[i] != b {
	 369  					match = false
	 370  					break
	 371  				}
	 372  			}
	 373  		}
	 374  		if !match {
	 375  			t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
	 376  		}
	 377  	}
	 378  
	 379  	n := testing.AllocsPerRun(100, func() {
	 380  		for _, tt := range tests {
	 381  			test(tt.name, tt.in, tt.want)
	 382  		}
	 383  	})
	 384  
	 385  	// The numbers below are only valid for 64-bit interface word sizes,
	 386  	// and gc. With 32-bit words there are more convT2E allocs, and
	 387  	// with gccgo, only pointers currently go in interface data.
	 388  	// So only care on amd64 gc for now.
	 389  	measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"
	 390  
	 391  	if n > 0.5 && measureAllocs {
	 392  		t.Fatalf("allocs = %v; want 0", n)
	 393  	}
	 394  
	 395  	// This one involves a convT2E allocation, string -> interface{}
	 396  	n = testing.AllocsPerRun(100, func() {
	 397  		test("string", "foo", "foo")
	 398  	})
	 399  	if n > 1.5 && measureAllocs {
	 400  		t.Fatalf("allocs = %v; want max 1", n)
	 401  	}
	 402  }
	 403  
	 404  // https://golang.org/issues/13905
	 405  func TestUserDefinedBytes(t *testing.T) {
	 406  	type userDefinedBytes []byte
	 407  	var u userDefinedBytes
	 408  	v := []byte("foo")
	 409  
	 410  	convertAssign(&u, v)
	 411  	if &u[0] == &v[0] {
	 412  		t.Fatal("userDefinedBytes got potentially dirty driver memory")
	 413  	}
	 414  }
	 415  
	 416  type Valuer_V string
	 417  
	 418  func (v Valuer_V) Value() (driver.Value, error) {
	 419  	return strings.ToUpper(string(v)), nil
	 420  }
	 421  
	 422  type Valuer_P string
	 423  
	 424  func (p *Valuer_P) Value() (driver.Value, error) {
	 425  	if p == nil {
	 426  		return "nil-to-str", nil
	 427  	}
	 428  	return strings.ToUpper(string(*p)), nil
	 429  }
	 430  
	 431  func TestDriverArgs(t *testing.T) {
	 432  	var nilValuerVPtr *Valuer_V
	 433  	var nilValuerPPtr *Valuer_P
	 434  	var nilStrPtr *string
	 435  	tests := []struct {
	 436  		args []interface{}
	 437  		want []driver.NamedValue
	 438  	}{
	 439  		0: {
	 440  			args: []interface{}{Valuer_V("foo")},
	 441  			want: []driver.NamedValue{
	 442  				{
	 443  					Ordinal: 1,
	 444  					Value:	 "FOO",
	 445  				},
	 446  			},
	 447  		},
	 448  		1: {
	 449  			args: []interface{}{nilValuerVPtr},
	 450  			want: []driver.NamedValue{
	 451  				{
	 452  					Ordinal: 1,
	 453  					Value:	 nil,
	 454  				},
	 455  			},
	 456  		},
	 457  		2: {
	 458  			args: []interface{}{nilValuerPPtr},
	 459  			want: []driver.NamedValue{
	 460  				{
	 461  					Ordinal: 1,
	 462  					Value:	 "nil-to-str",
	 463  				},
	 464  			},
	 465  		},
	 466  		3: {
	 467  			args: []interface{}{"plain-str"},
	 468  			want: []driver.NamedValue{
	 469  				{
	 470  					Ordinal: 1,
	 471  					Value:	 "plain-str",
	 472  				},
	 473  			},
	 474  		},
	 475  		4: {
	 476  			args: []interface{}{nilStrPtr},
	 477  			want: []driver.NamedValue{
	 478  				{
	 479  					Ordinal: 1,
	 480  					Value:	 nil,
	 481  				},
	 482  			},
	 483  		},
	 484  	}
	 485  	for i, tt := range tests {
	 486  		ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
	 487  		got, err := driverArgsConnLocked(nil, ds, tt.args)
	 488  		if err != nil {
	 489  			t.Errorf("test[%d]: %v", i, err)
	 490  			continue
	 491  		}
	 492  		if !reflect.DeepEqual(got, tt.want) {
	 493  			t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
	 494  		}
	 495  	}
	 496  }
	 497  
	 498  type dec struct {
	 499  	form				byte
	 500  	neg				 bool
	 501  	coefficient [16]byte
	 502  	exponent		int32
	 503  }
	 504  
	 505  func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
	 506  	coef := make([]byte, 16)
	 507  	copy(coef, d.coefficient[:])
	 508  	return d.form, d.neg, coef, d.exponent
	 509  }
	 510  
	 511  func (d *dec) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
	 512  	switch form {
	 513  	default:
	 514  		return fmt.Errorf("unknown form %d", form)
	 515  	case 1, 2:
	 516  		d.form = form
	 517  		d.neg = negative
	 518  		return nil
	 519  	case 0:
	 520  	}
	 521  	d.form = form
	 522  	d.neg = negative
	 523  	d.exponent = exponent
	 524  
	 525  	// This isn't strictly correct, as the extra bytes could be all zero,
	 526  	// ignore this for this test.
	 527  	if len(coefficient) > 16 {
	 528  		return fmt.Errorf("coefficient too large")
	 529  	}
	 530  	copy(d.coefficient[:], coefficient)
	 531  
	 532  	return nil
	 533  }
	 534  
	 535  type decFinite struct {
	 536  	neg				 bool
	 537  	coefficient [16]byte
	 538  	exponent		int32
	 539  }
	 540  
	 541  func (d decFinite) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
	 542  	coef := make([]byte, 16)
	 543  	copy(coef, d.coefficient[:])
	 544  	return 0, d.neg, coef, d.exponent
	 545  }
	 546  
	 547  func (d *decFinite) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
	 548  	switch form {
	 549  	default:
	 550  		return fmt.Errorf("unknown form %d", form)
	 551  	case 1, 2:
	 552  		return fmt.Errorf("unsupported form %d", form)
	 553  	case 0:
	 554  	}
	 555  	d.neg = negative
	 556  	d.exponent = exponent
	 557  
	 558  	// This isn't strictly correct, as the extra bytes could be all zero,
	 559  	// ignore this for this test.
	 560  	if len(coefficient) > 16 {
	 561  		return fmt.Errorf("coefficient too large")
	 562  	}
	 563  	copy(d.coefficient[:], coefficient)
	 564  
	 565  	return nil
	 566  }
	 567  
	 568  func TestDecimal(t *testing.T) {
	 569  	list := []struct {
	 570  		name string
	 571  		in	 decimalDecompose
	 572  		out	dec
	 573  		err	bool
	 574  	}{
	 575  		{name: "same", in: dec{exponent: -6}, out: dec{exponent: -6}},
	 576  
	 577  		// Ensure reflection is not used to assign the value by using different types.
	 578  		{name: "diff", in: decFinite{exponent: -6}, out: dec{exponent: -6}},
	 579  
	 580  		{name: "bad-form", in: dec{form: 200}, err: true},
	 581  	}
	 582  	for _, item := range list {
	 583  		t.Run(item.name, func(t *testing.T) {
	 584  			out := dec{}
	 585  			err := convertAssign(&out, item.in)
	 586  			if item.err {
	 587  				if err == nil {
	 588  					t.Fatalf("unexpected nil error")
	 589  				}
	 590  				return
	 591  			}
	 592  			if err != nil {
	 593  				t.Fatalf("unexpected error: %v", err)
	 594  			}
	 595  			if !reflect.DeepEqual(out, item.out) {
	 596  				t.Fatalf("got %#v want %#v", out, item.out)
	 597  			}
	 598  		})
	 599  	}
	 600  }
	 601  

View as plain text