Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
image-compression
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Elphel
image-compression
Commits
bbb3dc00
Commit
bbb3dc00
authored
Mar 08, 2022
by
Kelly Chang
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' of
https://git.elphel.com/nathaniel/image-compression
parents
1fcec44c
269f249e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
822 deletions
+0
-822
adaptive_arithmetic_compress.py
adaptive_arithmetic_compress.py
+0
-57
arithmetic.py
arithmetic.py
+0
-167
arithmeticcoding.py
arithmeticcoding.py
+0
-598
No files found.
adaptive_arithmetic_compress.py
deleted
100644 → 0
View file @
1fcec44c
#
# Compression application using adaptive arithmetic coding
#
# Usage: python adaptive-arithmetic-compress.py InputFile OutputFile
# Then use the corresponding adaptive-arithmetic-decompress.py application to recreate the original input file.
# Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1),
# and updates it after each byte encoded. The corresponding decompressor program also starts with a flat
# frequency table and updates it after each byte decoded. It is by design that the compressor and
# decompressor have synchronized states, so that the data can be decompressed properly.
#
# Copyright (c) Project Nayuki
#
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
#
import sys
import arithmeticcoding
# True when running on Python 3 or later (interpreter major version >= 3).
python3 = sys.version_info.major >= 3
# Command line main application function.
def main(args):
    # Exactly two arguments are required: the input path and the output path.
    # Anything else terminates the program with the usage message.
    if len(args) != 2:
        sys.exit("Usage: python adaptive-arithmetic-compress.py InputFile OutputFile")
    inputfile = args[0]
    outputfile = args[1]

    # Compress the input file into the output file. The bit stream is closed
    # in the 'finally' clause so that it is closed even when compression fails.
    with open(inputfile, "rb") as inp:
        bitout = arithmeticcoding.BitOutputStream(open(outputfile, "wb"))
        try:
            compress(inp, bitout)
        finally:
            bitout.close()
def compress(inp, bitout):
    # Adaptive model: 257 symbols (byte values 0..255 plus symbol 256 used as the
    # end-of-file marker), every one starting with a frequency of 1.
    model = arithmeticcoding.SimpleFrequencyTable(arithmeticcoding.FlatFrequencyTable(257))
    encoder = arithmeticcoding.ArithmeticEncoder(bitout)

    while True:
        # Read one byte; an empty result means the input is exhausted.
        chunk = inp.read(1)
        if not chunk:
            break
        # Indexing bytes yields an int on Python 3; Python 2 needs ord().
        value = chunk[0] if python3 else ord(chunk)
        # Encode with the current model, then update the model with the byte.
        encoder.write(model, value)
        model.increment(value)

    encoder.write(model, 256)  # EOF
    encoder.finish()  # Flush remaining code bits
# Main launcher
# Run main() with the command-line arguments (program name excluded) only when
# this file is executed as a script rather than imported.
if __name__ == "__main__":
    main(sys.argv[1:])
arithmetic.py
deleted
100644 → 0
View file @
1fcec44c
# Example implementation of simple arithmetic coding in Python (2.7+).
#
# USAGE
#
# python -i arithmetic.py
# >>> m = {'a': 1, 'b': 1, 'c': 1}
# >>> model = dirichlet(m)
# >>> encode(model, "aabbaacc")
# '00011110011110010'
#
# NOTES
#
# This implementation has many shortcomings, e.g.,
# - There are several inefficient tests, loops, and conversions
# - There are a few places where code is unnecessarily duplicated
# - It does not output the coded message as a stream
# - It can only code short messages due to machine precision
# - There is no defensive coding against errors (e.g., out-of-model symbols)
# - I've not implemented a decoder!
#
# The aim was to make the implementation here as close as possible to
# the algorithm described in lectures while giving some extra detail about
# routines such as finding extensions to binary intervals.
#
# For a more sophisticated implementation, please refer to:
#
# "Arithmetic Coding for Data Compression"
# I. H. Witten, R. M. Neal, and J. G. Cleary
# Communications of the ACM, Vol. 30 (6), 1987
#
# AUTHOR: Mark Reid
# CREATED: 2014-09-30
def encode(G, stream):
    '''
    Arithmetically encode the symbols of stream and return the code as a
    string of '0' and '1' characters.

    G is the guesser: called with the message seen so far (a string), it
    returns the distribution P(x|xs) over the next symbol.
    '''
    low, high = 0.0, 1.0       # Current message interval [low, high)
    seen, code = "", ""        # Message consumed so far, and the bits emitted so far
    dist = G(seen)             # Distribution for the first symbol

    for sym in stream:
        seen += sym

        # Narrow the message interval to the sub-interval assigned to sym
        F_lo, F_hi = cdf_interval(dist, sym)
        width = high - low
        low, high = low + width * F_lo, low + width * F_hi

        # Emit every bit that is already determined: grow the code while its
        # binary interval still surrounds [low, high)
        code = extend_around(code, low, high)

        # Ask the guesser for the distribution of the next symbol
        dist = G(seen)

    # End of stream: extend the code by the fewest bits that place its binary
    # interval inside the top half of [low, high)
    midpoint = low + (high - low) / 2
    return extend_inside(code, midpoint, high)
##############################################################################
# Models
def dirichlet(m):
    '''
    Build a Dirichlet model from the prior counts in m, a dictionary mapping
    each symbol to its count.

    The result is a function taking the sequence of symbols seen so far and
    returning a (symbol, probability) dictionary for the next symbol.
    '''

    def p(xs):
        # Start from a copy of the priors so that m itself is never modified,
        # then add one to the count of every observed symbol
        tallies = m.copy()
        for seen in xs:
            tallies[seen] = tallies[seen] + 1

        # Normalise the counts into probabilities
        denominator = sum(tallies.values())
        probabilities = {}
        for symbol, count in tallies.items():
            probabilities[symbol] = float(count) / denominator
        return probabilities

    return p
##############################################################################
# Interval methods
def cdf_interval(p, a):
    '''
    Compute the cumulative distribution interval [F(a'), F(a)) for the
    probabilities p (represented as a (symbol,prob) dict) where
    F(a) = P(x <= a) and a' is the symbol preceding a in sorted order.

    Returns the pair (F_lo, F_hi).
    Raises KeyError if a is not one of the symbols of p; previously the
    interval of the last symbol was returned silently in that case, so an
    out-of-model symbol was encoded as a different symbol.
    '''
    F_lo, F_hi = 0, 0
    for x in sorted(p):
        # Slide the interval up by the probability of each symbol in order
        F_lo, F_hi = F_hi, F_hi + p[x]
        if x == a:
            break
    else:
        # The loop finished without finding a (this also covers an empty p)
        raise KeyError(a)
    return F_lo, F_hi
def binary_interval(bs):
    '''
    Return (n, m, d) with integers n, m = n + 1 and d = 2**len(bs), which
    describe the interval [n/d, m/d) covered by all binary fractions that
    start with the binary string bs.
    '''
    # Read bs as an integer written in base 2 (same computation as to_rational)
    numerator = 0
    for digit in bs:
        numerator = numerator * 2 + int(digit)
    denominator = 2 ** len(bs)
    return numerator, numerator + 1, denominator
def to_rational(bs):
    '''Return (numerator, denominator) of the binary fraction 0.bs.'''
    numerator = 0
    for digit in bs:
        # Shift the digits read so far up one binary place, then add the new digit
        numerator = 2 * numerator + int(digit)
    denominator = 2 ** len(bs)
    return numerator, denominator
def around(bs, u, v):
    '''Return True when the binary interval [0.bs, 0.bs111...) contains [u, v).'''
    lo, hi, scale = binary_interval(bs)
    # Compare in units of 1/scale so the binary end points stay integers
    if not (lo <= u * scale):
        return False
    return v * scale <= hi
def extend_around(bs, u, v):
    '''Return the longest extension of the binary string bs whose interval
    still wraps around the interval [u, v).'''
    while True:
        # Try appending "0" first, then "1"; keep the first bit that fits
        for bit in ("0", "1"):
            if around(bs + bit, u, v):
                bs += bit
                break
        else:
            # Neither extension surrounds [u, v) any more, so bs is final
            return bs
def inside(bs, u, v):
    '''Return True when the binary interval [0.bs, 0.bs111...) lies within [u, v).'''
    lo, hi, scale = binary_interval(bs)
    # Compare in units of 1/scale so the binary end points stay integers
    if not (u * scale <= lo):
        return False
    return hi <= v * scale
def extend_inside(bs, u, v):
    '''Return the shortest extension of the binary string bs whose interval
    sits inside the interval [u, v).'''
    while not inside(bs, u, v):
        lo, hi, scale = binary_interval(bs)
        gap_below = u * scale - lo    # how far the binary interval reaches below u
        gap_above = hi - v * scale    # how far the binary interval reaches above v
        # Appending "1" halves the interval by raising its bottom; appending
        # "0" halves it by lowering its top. Shrink the side with the larger gap.
        bs += "1" if gap_below > gap_above else "0"
    return bs
arithmeticcoding.py
deleted
100644 → 0
View file @
1fcec44c
#
# Reference arithmetic coding
# Copyright (c) Project Nayuki
#
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
#
import sys
# True when running on Python 3 or later (interpreter major version >= 3).
python3 = sys.version_info.major >= 3
# ---- Arithmetic coding core classes ----
# Provides the state and behaviors that arithmetic coding encoders and decoders share.
class ArithmeticCoderBase(object):

    # Width in bits of 'low' and 'high'. Configurable and must be at least 1.
    STATE_SIZE = 32
    # Largest range that can occur during coding, i.e. 1000...000.
    MAX_RANGE = 1 << STATE_SIZE
    # Smallest range that can occur during coding, i.e. 010...010.
    MIN_RANGE = (MAX_RANGE >> 2) + 2
    # Largest total frequency a table may have at any time during coding.
    MAX_TOTAL = MIN_RANGE
    # STATE_SIZE one bits, i.e. 111...111.
    MASK = MAX_RANGE - 1
    # The top bit at width STATE_SIZE, i.e. 100...000.
    TOP_MASK = MAX_RANGE >> 1
    # The second highest bit at width STATE_SIZE, i.e. 010...000.
    SECOND_MASK = TOP_MASK >> 1

    # Creates a coder whose range initially spans the whole state space.
    def __init__(self):
        # Low end of the current range (conceptually followed by infinite 0s).
        self.low = 0
        # High end of the current range (conceptually followed by infinite 1s).
        self.high = ArithmeticCoderBase.MASK

    # Narrows the range [low, high] to the part assigned to 'symbol' by the
    # frequency table 'freqs', then renormalises, calling shift() once for every
    # leading bit that low and high share and underflow() once for every step in
    # which low = 01... and high = 10....
    # Raises AssertionError when the state is inconsistent, and ValueError when
    # the symbol has zero frequency or the table total exceeds MAX_TOTAL.
    def update(self, freqs, symbol):
        base = ArithmeticCoderBase
        full_mask = base.MASK

        # State check: low < high, and both fit in STATE_SIZE bits
        low, high = self.low, self.high
        if low >= high or (low & full_mask) != low or (high & full_mask) != high:
            raise AssertionError("Low or high out of range")
        span = high - low + 1
        if not (base.MIN_RANGE <= span <= base.MAX_RANGE):
            raise AssertionError("Range out of range")

        # Frequency table values check
        total = freqs.get_total()
        symlow = freqs.get_low(symbol)
        symhigh = freqs.get_high(symbol)
        if symlow == symhigh:
            raise ValueError("Symbol has zero frequency")
        if total > base.MAX_TOTAL:
            raise ValueError("Cannot code symbol because total is too large")

        # Scale the symbol's cumulative frequency interval onto the current range
        new_low = low + symlow * span // total
        new_high = low + symhigh * span // total - 1
        self.low, self.high = new_low, new_high

        # While low and high agree on their top bit: report it and shift it out,
        # feeding a 0 into low and a 1 into high
        while ((self.low ^ self.high) & base.TOP_MASK) == 0:
            self.shift()
            self.low = (self.low << 1) & full_mask
            self.high = ((self.high << 1) & full_mask) | 1

        # While the second highest bit of low is 1 and that of high is 0: report
        # the underflow and delete that bit, leaving the top bits in place
        # (0 for low, 1 for high)
        below_top = full_mask >> 1
        while (self.low & ~self.high & base.SECOND_MASK) != 0:
            self.underflow()
            self.low = (self.low << 1) & below_top
            self.high = ((self.high << 1) & below_top) | base.TOP_MASK | 1

    # Hook called when the top bits of 'low' and 'high' are equal.
    # Subclasses must override it.
    def shift(self):
        raise NotImplementedError()

    # Hook called when low=01(...) and high=10(...).
    # Subclasses must override it.
    def underflow(self):
        raise NotImplementedError()
# Encodes symbols and writes to an arithmetic-coded bit stream.
class ArithmeticEncoder(ArithmeticCoderBase):

    # Creates an encoder that writes its code bits to the given bit output stream.
    def __init__(self, bitout):
        super(ArithmeticEncoder, self).__init__()
        # The underlying bit output stream.
        self.output = bitout
        # Count of underflow bits seen but not yet written. Can grow without bound.
        self.num_underflow = 0

    # Encodes 'symbol' according to the frequency table 'freqs'. Updates the
    # coder state and may write some bits. A table that is not already a
    # CheckedFrequencyTable is wrapped in one first.
    def write(self, freqs, symbol):
        checked = freqs if isinstance(freqs, CheckedFrequencyTable) else CheckedFrequencyTable(freqs)
        self.update(checked, symbol)

    # Ends the encoding by writing a single 1 bit so that the output can be
    # decoded properly. Must be called at the end of every encoding process.
    # It only writes to the underlying output stream; it does not close it.
    def finish(self):
        self.output.write(1)

    # Writes the top bit of 'low', followed by one inverted copy of that bit for
    # every saved underflow, and then resets the underflow count.
    def shift(self):
        bit = self.low >> (ArithmeticCoderBase.STATE_SIZE - 1)
        self.output.write(bit)
        opposite = bit ^ 1
        for _ in range(self.num_underflow):
            self.output.write(opposite)
        self.num_underflow = 0

    # Remembers one more underflow bit to be written by the next shift().
    def underflow(self):
        self.num_underflow += 1
# Reads from an arithmetic-coded bit stream and decodes symbols.
class ArithmeticDecoder(ArithmeticCoderBase):

    # Creates a decoder that reads from the given bit input stream, and fills
    # the code buffer with the first STATE_SIZE bits.
    def __init__(self, bitin):
        super(ArithmeticDecoder, self).__init__()
        # The underlying bit input stream.
        self.input = bitin
        # The raw code bits currently buffered; always within [low, high].
        self.code = 0
        for _ in range(ArithmeticCoderBase.STATE_SIZE):
            self.code = self.code << 1 | self.read_code_bit()

    # Decodes and returns the next symbol according to the frequency table
    # 'freqs'. Updates the coder state and may read some bits. A table that is
    # not already a CheckedFrequencyTable is wrapped in one first.
    def read(self, freqs):
        if not isinstance(freqs, CheckedFrequencyTable):
            freqs = CheckedFrequencyTable(freqs)

        # Translate the buffered code from range scale to frequency table scale
        total = freqs.get_total()
        if total > ArithmeticCoderBase.MAX_TOTAL:
            raise ValueError("Cannot decode symbol because total is too large")
        span = self.high - self.low + 1
        offset = self.code - self.low
        value = ((offset + 1) * total - 1) // span
        assert value * span // total <= offset
        assert 0 <= value < total

        # Binary search for the highest symbol with freqs.get_low(symbol) <= value;
        # the answer always lies in [first, beyond)
        first = 0
        beyond = freqs.get_symbol_limit()
        while beyond - first > 1:
            probe = (first + beyond) >> 1
            if freqs.get_low(probe) > value:
                beyond = probe
            else:
                first = probe
        assert first + 1 == beyond

        symbol = first
        assert freqs.get_low(symbol) * span // total <= offset < freqs.get_high(symbol) * span // total
        self.update(freqs, symbol)
        if not (self.low <= self.code <= self.high):
            raise AssertionError("Code out of range")
        return symbol

    # Drops the top bit of the code and appends the next bit from the input.
    def shift(self):
        self.code = ((self.code << 1) & ArithmeticCoderBase.MASK) | self.read_code_bit()

    # Keeps the top bit of the code, deletes the second highest bit, and
    # appends the next bit from the input.
    def underflow(self):
        kept_top = self.code & ArithmeticCoderBase.TOP_MASK
        shifted_rest = (self.code << 1) & (ArithmeticCoderBase.MASK >> 1)
        self.code = kept_top | shifted_rest | self.read_code_bit()

    # Returns the next bit (0 or 1) from the input stream. The end of the
    # stream (signalled by -1) is treated as an infinite run of trailing zeros.
    def read_code_bit(self):
        bit = self.input.read()
        return 0 if bit == -1 else bit
# ---- Frequency table classes ----
# A table of symbol frequencies. The table holds data for symbols numbered from 0
# to get_symbol_limit()-1. Each symbol has a frequency, which is a non-negative integer.
# Frequency table objects are primarily used for getting cumulative symbol
# frequencies. These objects can be mutable depending on the implementation.
class FrequencyTable(object):

    # Number of symbols in the table (a positive number); symbols are numbered
    # from 0 to get_symbol_limit()-1.
    def get_symbol_limit(self):
        raise NotImplementedError()

    # Frequency of the given symbol; at least 0.
    def get(self, symbol):
        raise NotImplementedError()

    # Replaces the frequency of the given symbol with 'freq', which must be
    # at least 0.
    def set(self, symbol, freq):
        raise NotImplementedError()

    # Adds one to the frequency of the given symbol.
    def increment(self, symbol):
        raise NotImplementedError()

    # Sum of all symbol frequencies; at least 0 and always equal to
    # get_high(get_symbol_limit() - 1).
    def get_total(self):
        raise NotImplementedError()

    # Sum of the frequencies of every symbol strictly below the given symbol;
    # at least 0.
    def get_low(self, symbol):
        raise NotImplementedError()

    # Sum of the frequencies of the given symbol and every symbol below it;
    # at least 0.
    def get_high(self, symbol):
        raise NotImplementedError()
# An immutable frequency table where every symbol has the same frequency of 1.
# Useful as a fallback model when no statistics are available.
class FlatFrequencyTable(FrequencyTable):

    # Creates a flat table of 'numsyms' symbols, each with a frequency of 1.
    # Raises ValueError unless numsyms is at least 1.
    def __init__(self, numsyms):
        if numsyms < 1:
            raise ValueError("Number of symbols must be positive")
        # Total number of symbols, which is at least 1
        self.numsymbols = numsyms

    # Number of symbols in this table; at least 1.
    def get_symbol_limit(self):
        return self.numsymbols

    # Frequency of the given symbol, which is always 1.
    def get(self, symbol):
        self._check_symbol(symbol)
        return 1

    # Total of all frequencies, which always equals the number of symbols.
    def get_total(self):
        return self.numsymbols

    # Sum of the frequencies of all symbols strictly below 'symbol'; since every
    # frequency is 1 this is 'symbol' itself.
    def get_low(self, symbol):
        self._check_symbol(symbol)
        return symbol

    # Sum of the frequencies of 'symbol' and all symbols below it; since every
    # frequency is 1 this is 'symbol' + 1.
    def get_high(self, symbol):
        self._check_symbol(symbol)
        return symbol + 1

    # Raises ValueError unless 0 <= symbol < numsymbols.
    def _check_symbol(self, symbol):
        if not (0 <= symbol < self.numsymbols):
            raise ValueError("Symbol out of range")

    # String representation of this table. The format is subject to change.
    def __str__(self):
        return "FlatFrequencyTable={}".format(self.numsymbols)

    # Not supported: this frequency table is immutable.
    def set(self, symbol, freq):
        raise NotImplementedError()

    # Not supported: this frequency table is immutable.
    def increment(self, symbol):
        raise NotImplementedError()
# A mutable table of symbol frequencies. The number of symbols cannot be changed
# after construction. The current algorithm for calculating cumulative frequencies
# takes linear time, but there exist faster algorithms such as Fenwick trees.
class SimpleFrequencyTable(FrequencyTable):

    # Creates a simple frequency table from either:
    # - a FrequencyTable, whose frequencies are copied symbol by symbol, or
    # - any other iterable of symbol frequencies, which is copied into a list.
    # Raises ValueError if there is no symbol or any frequency is negative.
    def __init__(self, freqs):
        if isinstance(freqs, FrequencyTable):
            limit = freqs.get_symbol_limit()
            self.frequencies = [freqs.get(sym) for sym in range(limit)]
        else:
            # Assume it is a sequence type and make a copy
            self.frequencies = list(freqs)

        # 'frequencies' holds one non-negative entry per symbol; at least 1 entry.
        if len(self.frequencies) < 1:
            raise ValueError("At least 1 symbol needed")
        if any(freq < 0 for freq in self.frequencies):
            raise ValueError("Negative frequency")

        # Always equal to the sum of 'frequencies'
        self.total = sum(self.frequencies)

        # cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i
        # (exclusive). Built lazily; None means it must be rebuilt before use.
        self.cumulative = None

    # Number of symbols in this table; at least 1.
    def get_symbol_limit(self):
        return len(self.frequencies)

    # Frequency of the given symbol; at least 0.
    def get(self, symbol):
        self._check_symbol(symbol)
        return self.frequencies[symbol]

    # Replaces the frequency of the given symbol with 'freq' (at least 0).
    # The arguments are validated before anything is modified.
    def set(self, symbol, freq):
        self._check_symbol(symbol)
        if freq < 0:
            raise ValueError("Negative frequency")
        others = self.total - self.frequencies[symbol]
        assert others >= 0
        self.total = others + freq
        self.frequencies[symbol] = freq
        self.cumulative = None  # cached cumulative sums are now stale

    # Adds one to the frequency of the given symbol.
    def increment(self, symbol):
        self._check_symbol(symbol)
        self.total += 1
        self.frequencies[symbol] += 1
        self.cumulative = None  # cached cumulative sums are now stale

    # Sum of all symbol frequencies.
    def get_total(self):
        return self.total

    # Sum of the frequencies of every symbol strictly below the given symbol.
    def get_low(self, symbol):
        self._check_symbol(symbol)
        if self.cumulative is None:
            self._init_cumulative()
        return self.cumulative[symbol]

    # Sum of the frequencies of the given symbol and every symbol below it.
    def get_high(self, symbol):
        self._check_symbol(symbol)
        if self.cumulative is None:
            self._init_cumulative()
        return self.cumulative[symbol + 1]

    # Rebuilds the list of cumulative symbol frequencies.
    def _init_cumulative(self):
        running = 0
        sums = [running]
        for freq in self.frequencies:
            running += freq
            sums.append(running)
        assert running == self.total
        self.cumulative = sums

    # Raises ValueError unless 0 <= symbol < len(frequencies).
    def _check_symbol(self, symbol):
        if not (0 <= symbol < len(self.frequencies)):
            raise ValueError("Symbol out of range")

    # String representation with one "symbol<TAB>frequency" line per symbol.
    # Useful for debugging only; the format is subject to change.
    def __str__(self):
        return "".join("{}\t{}\n".format(i, freq) for (i, freq) in enumerate(self.frequencies))
# A wrapper that checks the preconditions (arguments) and postconditions (return value) of all
# the frequency table methods. Useful for finding faults in a frequency table implementation.
class CheckedFrequencyTable(FrequencyTable):

    # Wraps the given frequency table; every call is forwarded to it and checked.
    def __init__(self, freqtab):
        # The underlying frequency table that holds the data
        self.freqtable = freqtab

    # Forwards to the wrapped table; AssertionError if the limit is not positive.
    def get_symbol_limit(self):
        limit = self.freqtable.get_symbol_limit()
        if limit <= 0:
            raise AssertionError("Non-positive symbol limit")
        return limit

    # Forwards to the wrapped table; AssertionError if it accepted an
    # out-of-range symbol or returned a negative frequency.
    def get(self, symbol):
        frequency = self.freqtable.get(symbol)
        if not self._is_symbol_in_range(symbol):
            raise AssertionError("ValueError expected")
        if frequency < 0:
            raise AssertionError("Negative symbol frequency")
        return frequency

    # Forwards to the wrapped table; AssertionError if the total is negative.
    def get_total(self):
        total = self.freqtable.get_total()
        if total < 0:
            raise AssertionError("Negative total frequency")
        return total

    # Returns the wrapped table's get_low(symbol) after checking that
    # 0 <= low <= high <= total. For an out-of-range symbol the wrapped table is
    # still called (it is expected to raise); if it does not, AssertionError.
    def get_low(self, symbol):
        if not self._is_symbol_in_range(symbol):
            self.freqtable.get_low(symbol)
            raise AssertionError("ValueError expected")
        low = self.freqtable.get_low(symbol)
        high = self.freqtable.get_high(symbol)
        if not (0 <= low <= high <= self.freqtable.get_total()):
            raise AssertionError("Symbol low cumulative frequency out of range")
        return low

    # Returns the wrapped table's get_high(symbol) after checking that
    # 0 <= low <= high <= total. For an out-of-range symbol the wrapped table is
    # still called (it is expected to raise); if it does not, AssertionError.
    def get_high(self, symbol):
        if not self._is_symbol_in_range(symbol):
            self.freqtable.get_high(symbol)
            raise AssertionError("ValueError expected")
        low = self.freqtable.get_low(symbol)
        high = self.freqtable.get_high(symbol)
        if not (0 <= low <= high <= self.freqtable.get_total()):
            raise AssertionError("Symbol high cumulative frequency out of range")
        return high

    # String representation that embeds the wrapped table's own representation.
    def __str__(self):
        return "CheckFrequencyTable (" + str(self.freqtable) + ")"

    # Forwards to the wrapped table; AssertionError if it accepted an
    # out-of-range symbol or a negative frequency.
    def set(self, symbol, freq):
        self.freqtable.set(symbol, freq)
        if not self._is_symbol_in_range(symbol) or freq < 0:
            raise AssertionError("ValueError expected")

    # Forwards to the wrapped table; AssertionError if it accepted an
    # out-of-range symbol.
    def increment(self, symbol):
        self.freqtable.increment(symbol)
        if not self._is_symbol_in_range(symbol):
            raise AssertionError("ValueError expected")

    # True when 0 <= symbol < get_symbol_limit().
    def _is_symbol_in_range(self, symbol):
        limit = self.get_symbol_limit()
        return 0 <= symbol < limit
# ---- Bit-oriented I/O streams ----
# A stream of bits that can be read. Because they come from an underlying byte stream,
# the total number of bits is always a multiple of 8. The bits are read in big endian.
class BitInputStream(object):

    # Constructs a bit input stream based on the given byte input stream, which
    # must return byte strings from read(1) (i.e. be opened in binary mode).
    def __init__(self, inp):
        # The underlying byte stream to read from
        self.input = inp
        # Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached
        self.currentbyte = 0
        # Number of remaining bits in the current byte, always between 0 and 7 (inclusive)
        self.numbitsremaining = 0

    # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if
    # the end of stream is reached. The end of stream always occurs on a byte boundary.
    # Bits of each byte are delivered from the most significant to the least significant.
    def read(self):
        if self.currentbyte == -1:
            return -1
        if self.numbitsremaining == 0:
            temp = self.input.read(1)
            if len(temp) == 0:
                # Remember the end of stream so later calls return -1 without reading again
                self.currentbyte = -1
                return -1
            # bytearray turns a one-byte string into its integer value on both
            # Python 2 and Python 3, so no interpreter version check is needed
            self.currentbyte = bytearray(temp)[0]
            self.numbitsremaining = 8
        self.numbitsremaining -= 1
        return (self.currentbyte >> self.numbitsremaining) & 1

    # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError
    # if the end of stream is reached. The end of stream always occurs on a byte boundary.
    def read_no_eof(self):
        result = self.read()
        if result == -1:
            raise EOFError()
        return result

    # Closes this stream and the underlying input stream.
    # Afterwards read() returns -1.
    def close(self):
        self.input.close()
        self.currentbyte = -1
        self.numbitsremaining = 0

    # Context manager support, so that the stream can be used in a 'with'
    # statement and is closed when the block is left.
    def __enter__(self):
        return self

    # Closes the stream; returns False so that exceptions are not suppressed.
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
# A stream where bits can be written to. Because they are written to an underlying
# byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits.
# The bits are written in big endian.
class BitOutputStream(object):

    # Constructs a bit output stream based on the given byte output stream, which
    # must accept byte strings in write() (i.e. be opened in binary mode).
    def __init__(self, out):
        self.output = out  # The underlying byte stream to write to
        self.currentbyte = 0  # The accumulated bits for the current byte, always in the range [0x00, 0xFF]
        self.numbitsfilled = 0  # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive)

    # Writes a bit to the stream. The given bit must be 0 or 1, otherwise
    # ValueError is raised. Bits fill each byte from the most significant to the
    # least significant; a byte is written to the underlying stream as soon as
    # its 8 bits are complete.
    def write(self, b):
        if b not in (0, 1):
            raise ValueError("Argument must be 0 or 1")
        self.currentbyte = (self.currentbyte << 1) | b
        self.numbitsfilled += 1
        if self.numbitsfilled == 8:
            # bytes(bytearray(...)) builds a one-byte string on both Python 2
            # and Python 3, so no interpreter version check is needed
            self.output.write(bytes(bytearray((self.currentbyte,))))
            self.currentbyte = 0
            self.numbitsfilled = 0

    # Closes this stream and the underlying output stream. If called when this
    # bit stream is not at a byte boundary, then the minimum number of "0" bits
    # (between 0 and 7 of them) are written as padding to reach the next byte boundary.
    # The underlying stream is closed even if writing the padding fails.
    def close(self):
        try:
            while self.numbitsfilled != 0:
                self.write(0)
        finally:
            self.output.close()

    # Context manager support, so that the stream can be used in a 'with'
    # statement and is padded and closed when the block is left.
    def __enter__(self):
        return self

    # Closes the stream; returns False so that exceptions are not suppressed.
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment