aboutsummaryrefslogtreecommitdiff
path: root/bitreader.go
blob: 3fbc2219b4a54bd328bcf934bffd429d8c80e2be (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Package bitreader is a simple bit reader with big/little-endian support for Go.
// It reads from any io.Reader, such as an *os.File, or a byte slice wrapped with bytes.NewReader(slice).
// Version 2 is implemented with bitwise operations.
// Supports reading up to 64 bits at a time.
// Includes wrapper functions for the most commonly used data types.
// All functions return errors except the Try* wrapper functions, which panic instead.
// Thanks to github.com/mlugg for the big help!
package bitreader

import (
	"fmt"
	"io"
	"math"
)

// ReaderType holds the state of a bit reader layered over an io.Reader.
// Whenever index is 0, the next bit read first pulls a fresh byte from
// stream into curByte.
type ReaderType struct {
	// stream is the source the bytes are pulled from.
	stream io.Reader
	// index is the position of the next bit within curByte, in [0, 7].
	index uint8
	// curByte is the byte bits are currently being taken from.
	curByte byte
	// le selects little-endian (least-significant-bit-first) reading.
	le bool
}

// Reader returns a new *ReaderType that pulls bytes from stream.
// When le is true, bits are consumed least-significant first and multi-bit
// values are assembled in little-endian order; otherwise big-endian.
// Nothing is read from stream until the first bit is requested.
func Reader(stream io.Reader, le bool) *ReaderType {
	r := new(ReaderType)
	r.stream = stream
	r.le = le
	// index and curByte keep their zero values: index 0 forces a byte to be
	// loaded on the first read, so curByte's starting value never matters.
	return r
}

// TryReadBool reads a single bit and reports whether it is set.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadBool() bool {
	set, err := reader.ReadBool()
	if err == nil {
		return set
	}
	panic(err)
}

// TryReadInt1 reads one bit and returns it as a uint8 (0 or 1).
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadInt1() uint8 {
	bit, err := reader.ReadBits(1)
	if err == nil {
		return uint8(bit)
	}
	panic(err)
}

// TryReadInt8 reads 8 bits and returns them as an unsigned uint8.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadInt8() uint8 {
	raw, err := reader.ReadBits(8)
	if err == nil {
		return uint8(raw)
	}
	panic(err)
}

// TryReadInt16 reads 16 bits and returns them as an unsigned uint16.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadInt16() uint16 {
	raw, err := reader.ReadBits(16)
	if err == nil {
		return uint16(raw)
	}
	panic(err)
}

// TryReadInt32 reads 32 bits and returns them as an unsigned uint32.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadInt32() uint32 {
	raw, err := reader.ReadBits(32)
	if err == nil {
		return uint32(raw)
	}
	panic(err)
}

// TryReadInt64 reads 64 bits and returns them unchanged as a uint64.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadInt64() uint64 {
	raw, err := reader.ReadBits(64)
	if err == nil {
		return raw
	}
	panic(err)
}

// TryReadFloat32 reads 32 bits and reinterprets them as an IEEE 754
// single-precision float via math.Float32frombits.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadFloat32() float32 {
	raw, err := reader.ReadBits(32)
	if err == nil {
		return math.Float32frombits(uint32(raw))
	}
	panic(err)
}

// TryReadFloat64 reads 64 bits and reinterprets them as an IEEE 754
// double-precision float via math.Float64frombits.
// It panics with the underlying error if the read fails.
func (reader *ReaderType) TryReadFloat64() float64 {
	raw, err := reader.ReadBits(64)
	if err == nil {
		return math.Float64frombits(raw)
	}
	panic(err)
}

// SkipBits advances the reader by the given number of bits without
// returning them. Whole bytes are pulled from the stream in one
// io.ReadFull call; any leftover bits are consumed one at a time.
// Skipping 0 bits is a no-op.
//
// Returns an error if bits is negative, or if the stream ends before the
// requested number of bits could be skipped (for the bulk byte skip this
// is io.EOF if no byte was available and io.ErrUnexpectedEOF if only some
// were).
func (reader *ReaderType) SkipBits(bits int) error {
	if bits < 0 {
		return fmt.Errorf("SkipBits(bits) ERROR: Bits number should not be negative.")
	}
	// Skipping a whole number of bytes leaves the bit index unchanged:
	// n*8 bits past bit i of the current byte is bit i of the byte n
	// positions later, so the bytes can be consumed in bulk.
	byteCount := bits / 8
	if byteCount > 0 {
		buf := make([]byte, byteCount)
		// A bare Read may legally return fewer bytes than requested with a
		// nil error; io.ReadFull keeps reading until the buffer is full.
		if _, err := io.ReadFull(reader.stream, buf); err != nil {
			return err
		}
		// The last byte read becomes the current byte. If index is 0 it is
		// already consumed and gets replaced on the next read; otherwise the
		// following bits come from it at the unchanged index.
		reader.curByte = buf[byteCount-1]
	}
	// The remaining bits may still be available in curByte, so they are not
	// fetched through the stream directly.
	for i := byteCount * 8; i < bits; i++ {
		if _, err := reader.readBit(); err != nil {
			return err
		}
	}
	return nil
}

// SkipBytes advances the reader by the given number of whole bytes
// (bytes*8 bits), delegating to SkipBits and returning its error, if any.
func (reader *ReaderType) SkipBytes(bytes int) error {
	return reader.SkipBits(8 * bytes)
}

// ReadBits reads the requested number of bits (1 to 64) and returns them
// packed into a uint64. In little-endian mode the first bit read becomes
// the least significant bit of the result; in big-endian mode it becomes
// the most significant bit of the requested width.
//
// Returns an error if bits is out of range or the stream runs out of bits;
// the returned value is 0 whenever an error is returned.
func (reader *ReaderType) ReadBits(bits int) (uint64, error) {
	if !(bits >= 1 && bits <= 64) {
		return 0, fmt.Errorf("ReadBits(bits) ERROR: Bits number should be between 1 and 64.")
	}
	var result uint64
	for pos := 0; pos < bits; pos++ {
		bit, err := reader.readBit()
		if err != nil {
			return 0, err
		}
		// Choose the destination position for this bit.
		shift := bits - 1 - pos
		if reader.le {
			shift = pos
		}
		result |= uint64(bit) << shift
	}
	return result, nil
}

// ReadBytes reads the requested number of whole bytes (1 to 8) through
// ReadBits and returns them packed into a uint64, using the reader's
// configured endianness.
//
// Returns an error if bytes is out of range or the stream runs out of bits;
// the returned value is 0 whenever an error is returned.
func (reader *ReaderType) ReadBytes(bytes int) (uint64, error) {
	if bytes < 1 || bytes > 8 {
		return 0, fmt.Errorf("ReadBytes(bytes) ERROR: Bytes number should be between 1 and 8.")
	}
	value, err := reader.ReadBits(8 * bytes)
	if err != nil {
		// Never hand back a partial value alongside an error.
		value = 0
	}
	return value, err
}

// ReadBool reads a single bit and reports whether it equals 1.
//
// Returns false and the error if no bit could be read.
func (reader *ReaderType) ReadBool() (bool, error) {
	bit, err := reader.readBit()
	if err == nil {
		return bit == 1, nil
	}
	return false, err
}

// readBit returns the next bit from the stream as 0 or 1.
// When index is 0 the previous byte has been fully consumed, so a new byte
// is loaded into curByte first. In little-endian mode bits are taken
// least-significant first; otherwise most-significant first.
//
// Returns an error if a new byte is needed and the stream cannot supply one
// (io.EOF when the stream is exhausted).
func (reader *ReaderType) readBit() (uint8, error) {
	if reader.index == 0 {
		// io.Reader allows Read to return (0, nil), or to deliver the final
		// byte together with io.EOF. io.ReadFull handles both: it keeps
		// reading until the byte arrives, and it drops the error once the
		// buffer is full.
		buf := make([]byte, 1)
		if _, err := io.ReadFull(reader.stream, buf); err != nil {
			return 0, err
		}
		reader.curByte = buf[0]
	}
	shift := 7 - reader.index
	if reader.le {
		shift = reader.index
	}
	reader.index = (reader.index + 1) % 8
	return (reader.curByte >> shift) & 1, nil
}