Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
I
Image Compression
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Nathaniel Callens
Image Compression
Commits
bbb3dc00
Commit
bbb3dc00
authored
Mar 08, 2022
by
Kelly Chang
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' of
https://git.elphel.com/nathaniel/image-compression
parents
1fcec44c
269f249e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
822 deletions
+0
-822
adaptive_arithmetic_compress.py
adaptive_arithmetic_compress.py
+0
-57
arithmetic.py
arithmetic.py
+0
-167
arithmeticcoding.py
arithmeticcoding.py
+0
-598
No files found.
adaptive_arithmetic_compress.py
deleted
100644 → 0
View file @
1fcec44c
#
# Compression application using adaptive arithmetic coding
#
# Usage: python adaptive-arithmetic-compress.py InputFile OutputFile
# Then use the corresponding adaptive-arithmetic-decompress.py application to recreate the original input file.
# Note that the application starts with a flat frequency table of 257 symbols (all set to a frequency of 1),
# and updates it after each byte encoded. The corresponding decompressor program also starts with a flat
# frequency table and updates it after each byte decoded. It is by design that the compressor and
# decompressor have synchronized states, so that the data can be decompressed properly.
#
# Copyright (c) Project Nayuki
#
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
#
import
sys
import
arithmeticcoding
python3
=
sys
.
version_info
.
major
>=
3
# Command line main application function.
def
main
(
args
):
# Handle command line arguments
if
len
(
args
)
!=
2
:
sys
.
exit
(
"Usage: python adaptive-arithmetic-compress.py InputFile OutputFile"
)
inputfile
=
args
[
0
]
outputfile
=
args
[
1
]
# Perform file compression
with
open
(
inputfile
,
"rb"
)
as
inp
:
bitout
=
arithmeticcoding
.
BitOutputStream
(
open
(
outputfile
,
"wb"
))
try
:
compress
(
inp
,
bitout
)
finally
:
bitout
.
close
()
def
compress
(
inp
,
bitout
):
initfreqs
=
arithmeticcoding
.
FlatFrequencyTable
(
257
)
freqs
=
arithmeticcoding
.
SimpleFrequencyTable
(
initfreqs
)
enc
=
arithmeticcoding
.
ArithmeticEncoder
(
bitout
)
while
True
:
# Read and encode one byte
symbol
=
inp
.
read
(
1
)
if
len
(
symbol
)
==
0
:
break
symbol
=
symbol
[
0
]
if
python3
else
ord
(
symbol
)
enc
.
write
(
freqs
,
symbol
)
freqs
.
increment
(
symbol
)
enc
.
write
(
freqs
,
256
)
# EOF
enc
.
finish
()
# Flush remaining code bits
# Main launcher
if
__name__
==
"__main__"
:
main
(
sys
.
argv
[
1
:
])
arithmetic.py
deleted
100644 → 0
View file @
1fcec44c
# Example implementation of simple arithmetic coding in Python (2.7+).
#
# USAGE
#
# python -i arithmetic.py
# >>> m = {'a': 1, 'b': 1, 'c': 1}
# >>> model = dirichlet(m)
# >>> encode(model, "aabbaacc")
# '00011110011110010'
#
# NOTES
#
# This implementation has many shortcomings, e.g.,
# - There are several inefficient tests, loops, and conversions
# - There are a few places where code is uncessarily duplicated
# - It does not output the coded message as a stream
# - It can only code short messages due to machine precision
# - The is no defensive coding against errors (e.g., out-of-model symbols)
# - I've not implemented a decoder!
#
# The aim was to make the implementation here as close as possible to
# the algorithm described in lectures while giving some extra detail about
# routines such as finding extensions to binary intervals.
#
# For a more sophisticated implementation, please refer to:
#
# "Arithmetic Coding for Data Compression"
# I. H. Witten, R. M. Neal, and J. G. Cleary
# Communications of the ACM, Col. 30 (6), 1987
#
# AUTHOR: Mark Reid
# CREATED: 2014-09-30
def
encode
(
G
,
stream
):
'''
Arithmetically encodes the given stream using the guesser function G
which returns probabilities over symbols P(x|xs) given a sequence xs.
'''
u
,
v
=
0.0
,
1.0
# The interval [u, v) for the message
xs
,
bs
=
""
,
""
# The message xs, and binary code bs
p
=
G
(
xs
)
# Compute the initial distribution over symbols
# Iterate through stream, repeatedly finding the longest binary code
# that surrounds the interval for the message so far
for
x
in
stream
:
# Record the new symbol
xs
+=
x
# Find the interval for the message so far
F_lo
,
F_hi
=
cdf_interval
(
p
,
x
)
u
,
v
=
u
+
(
v
-
u
)
*
F_lo
,
u
+
(
v
-
u
)
*
F_hi
# Find a binary code whose interval surrounds [u,v)
bs
=
extend_around
(
bs
,
u
,
v
)
# Update the symbol probabilities
p
=
G
(
xs
)
# Stream finished so find shortest extension of the code that sits inside
# the top half of [u, v)
bs
=
extend_inside
(
bs
,
u
+
(
v
-
u
)
/
2
,
v
)
return
bs
##############################################################################
# Models
def
dirichlet
(
m
):
'''
Returns a Dirichlet model (as a function) for probabilities with
prior counts given by the symbol to count dictionary m.
Probabilities returned by the returned functions are (symbol, prob)
dictionaries.
'''
# Build a function that returns P(x|xs) based on the priors in m
# and the counts of the symbols in xs
def
p
(
xs
):
counts
=
m
.
copy
()
for
x
in
xs
:
counts
[
x
]
+=
1
total
=
sum
(
counts
.
values
())
return
{
a
:
float
(
c
)
/
total
for
a
,
c
in
counts
.
items
()
}
# Return the constructed function
return
p
##############################################################################
# Interval methods
def
cdf_interval
(
p
,
a
):
'''
Compute the cumulative distribution interval [F(a'), F(a)) for the
probabilities p (represented as a (symbol,prob) dict) where
F(a) = P(x <= a) and a' is the symbol preceeding a.
'''
F_lo
,
F_hi
=
0
,
0
A
=
sorted
(
p
)
for
x
in
A
:
F_lo
,
F_hi
=
F_hi
,
F_hi
+
p
[
x
]
if
x
==
a
:
break
return
F_lo
,
F_hi
def
binary_interval
(
bs
):
'''
Returns an interval [n, m) for n and m integers, and denominator d
representing the interval [n/d, m/d) for the binary string bs.
'''
n
,
d
=
to_rational
(
bs
)
return
n
,
n
+
1
,
d
def
to_rational
(
bs
):
'''Return numerator and denominator for ratio of 0.bs.'''
n
=
0
for
b
in
bs
:
n
*=
2
n
+=
int
(
b
)
return
n
,
2
**
len
(
bs
)
def
around
(
bs
,
u
,
v
):
'''Tests whether [0.bs, 0.bs111...) contains [u, v).'''
n
,
m
,
d
=
binary_interval
(
bs
)
return
(
n
<=
u
*
d
)
and
(
v
*
d
<=
m
)
def
extend_around
(
bs
,
u
,
v
):
'''Find the longest extension of the given binary string so its interval
wraps around the interval [u, v).'''
contained
=
True
while
contained
:
if
around
(
bs
+
"0"
,
u
,
v
):
bs
+=
"0"
elif
around
(
bs
+
"1"
,
u
,
v
):
bs
+=
"1"
else
:
contained
=
False
return
bs
def
inside
(
bs
,
u
,
v
):
'''Tests whether [0.bs, 0.bs111...) is contained by [u, v).'''
n
,
m
,
d
=
binary_interval
(
bs
)
return
(
u
*
d
<=
n
)
and
(
m
<=
v
*
d
)
def
extend_inside
(
bs
,
u
,
v
):
'''Find the shortest extension of the given binary string so its interval
sits inside the interval [u, v).'''
while
not
inside
(
bs
,
u
,
v
):
# Test whether gap between binary interval and [u,v) is bigger at the
# bottom than at the top
n
,
m
,
d
=
binary_interval
(
bs
)
if
u
*
d
-
n
>
m
-
v
*
d
:
bs
+=
"1"
# If so, move bottom up by halving
else
:
bs
+=
"0"
# If not, move top down by halving
return
bs
arithmeticcoding.py
deleted
100644 → 0
View file @
1fcec44c
#
# Reference arithmetic coding
# Copyright (c) Project Nayuki
#
# https://www.nayuki.io/page/reference-arithmetic-coding
# https://github.com/nayuki/Reference-arithmetic-coding
#
import
sys
python3
=
sys
.
version_info
.
major
>=
3
# ---- Arithmetic coding core classes ----
# Provides the state and behaviors that arithmetic coding encoders and decoders share.
class
ArithmeticCoderBase
(
object
):
# Number of bits for 'low' and 'high'. Configurable and must be at least 1.
STATE_SIZE
=
32
# Maximum range during coding (trivial), i.e. 1000...000.
MAX_RANGE
=
1
<<
STATE_SIZE
# Minimum range during coding (non-trivial), i.e. 010...010.
MIN_RANGE
=
(
MAX_RANGE
>>
2
)
+
2
# Maximum allowed total frequency at all times during coding.
MAX_TOTAL
=
MIN_RANGE
# Mask of STATE_SIZE ones, i.e. 111...111.
MASK
=
MAX_RANGE
-
1
# Mask of the top bit at width STATE_SIZE, i.e. 100...000.
TOP_MASK
=
MAX_RANGE
>>
1
# Mask of the second highest bit at width STATE_SIZE, i.e. 010...000.
SECOND_MASK
=
TOP_MASK
>>
1
# Constructs an arithmetic coder, which initializes the code range.
def
__init__
(
self
):
# Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s.
self
.
low
=
0
# High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s.
self
.
high
=
ArithmeticCoderBase
.
MASK
# Updates the code range (low and high) of this arithmetic coder as a result
# of processing the given symbol with the given frequency table.
# Invariants that are true before and after encoding/decoding each symbol:
# - 0 <= low <= code <= high < 2^STATE_SIZE. ('code' exists only in the decoder.)
# Therefore these variables are unsigned integers of STATE_SIZE bits.
# - (low < 1/2 * 2^STATE_SIZE) && (high >= 1/2 * 2^STATE_SIZE).
# In other words, they are in different halves of the full range.
# - (low < 1/4 * 2^STATE_SIZE) || (high >= 3/4 * 2^STATE_SIZE).
# In other words, they are not both in the middle two quarters.
# - Let range = high - low + 1, then MAX_RANGE/4 < MIN_RANGE <= range
# <= MAX_RANGE = 2^STATE_SIZE. These invariants for 'range' essentially
# dictate the maximum total that the incoming frequency table can have.
def
update
(
self
,
freqs
,
symbol
):
# State check
low
=
self
.
low
high
=
self
.
high
if
low
>=
high
or
(
low
&
ArithmeticCoderBase
.
MASK
)
!=
low
or
(
high
&
ArithmeticCoderBase
.
MASK
)
!=
high
:
raise
AssertionError
(
"Low or high out of range"
)
range
=
high
-
low
+
1
if
not
(
ArithmeticCoderBase
.
MIN_RANGE
<=
range
<=
ArithmeticCoderBase
.
MAX_RANGE
):
raise
AssertionError
(
"Range out of range"
)
# Frequency table values check
total
=
freqs
.
get_total
()
symlow
=
freqs
.
get_low
(
symbol
)
symhigh
=
freqs
.
get_high
(
symbol
)
if
symlow
==
symhigh
:
raise
ValueError
(
"Symbol has zero frequency"
)
if
total
>
ArithmeticCoderBase
.
MAX_TOTAL
:
raise
ValueError
(
"Cannot code symbol because total is too large"
)
# Update range
newlow
=
low
+
symlow
*
range
//
total
newhigh
=
low
+
symhigh
*
range
//
total
-
1
self
.
low
=
newlow
self
.
high
=
newhigh
# While the highest bits are equal
while
((
self
.
low
^
self
.
high
)
&
ArithmeticCoderBase
.
TOP_MASK
)
==
0
:
self
.
shift
()
self
.
low
=
(
self
.
low
<<
1
)
&
ArithmeticCoderBase
.
MASK
self
.
high
=
((
self
.
high
<<
1
)
&
ArithmeticCoderBase
.
MASK
)
|
1
# While the second highest bit of low is 1 and the second highest bit of high is 0
while
(
self
.
low
&
~
self
.
high
&
ArithmeticCoderBase
.
SECOND_MASK
)
!=
0
:
self
.
underflow
()
self
.
low
=
(
self
.
low
<<
1
)
&
(
ArithmeticCoderBase
.
MASK
>>
1
)
self
.
high
=
((
self
.
high
<<
1
)
&
(
ArithmeticCoderBase
.
MASK
>>
1
))
|
ArithmeticCoderBase
.
TOP_MASK
|
1
# Called to handle the situation when the top bit of 'low' and 'high' are equal.
def
shift
(
self
):
raise
NotImplementedError
()
# Called to handle the situation when low=01(...) and high=10(...).
def
underflow
(
self
):
raise
NotImplementedError
()
# Encodes symbols and writes to an arithmetic-coded bit stream.
class
ArithmeticEncoder
(
ArithmeticCoderBase
):
# Constructs an arithmetic coding encoder based on the given bit output stream.
def
__init__
(
self
,
bitout
):
super
(
ArithmeticEncoder
,
self
)
.
__init__
()
# The underlying bit output stream.
self
.
output
=
bitout
# Number of saved underflow bits. This value can grow without bound.
self
.
num_underflow
=
0
# Encodes the given symbol based on the given frequency table.
# This updates this arithmetic coder's state and may write out some bits.
def
write
(
self
,
freqs
,
symbol
):
if
not
isinstance
(
freqs
,
CheckedFrequencyTable
):
freqs
=
CheckedFrequencyTable
(
freqs
)
self
.
update
(
freqs
,
symbol
)
# Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly.
# It is important that this method must be called at the end of the each encoding process.
# Note that this method merely writes data to the underlying output stream but does not close it.
def
finish
(
self
):
self
.
output
.
write
(
1
)
def
shift
(
self
):
bit
=
self
.
low
>>
(
ArithmeticCoderBase
.
STATE_SIZE
-
1
)
self
.
output
.
write
(
bit
)
# Write out the saved underflow bits
for
i
in
range
(
self
.
num_underflow
):
self
.
output
.
write
(
bit
^
1
)
self
.
num_underflow
=
0
def
underflow
(
self
):
self
.
num_underflow
+=
1
# Reads from an arithmetic-coded bit stream and decodes symbols.
class
ArithmeticDecoder
(
ArithmeticCoderBase
):
# Constructs an arithmetic coding decoder based on the
# given bit input stream, and fills the code bits.
def
__init__
(
self
,
bitin
):
super
(
ArithmeticDecoder
,
self
)
.
__init__
()
# The underlying bit input stream.
self
.
input
=
bitin
# The current raw code bits being buffered, which is always in the range [low, high].
self
.
code
=
0
for
i
in
range
(
ArithmeticCoderBase
.
STATE_SIZE
):
self
.
code
=
self
.
code
<<
1
|
self
.
read_code_bit
()
# Decodes the next symbol based on the given frequency table and returns it.
# Also updates this arithmetic coder's state and may read in some bits.
def
read
(
self
,
freqs
):
if
not
isinstance
(
freqs
,
CheckedFrequencyTable
):
freqs
=
CheckedFrequencyTable
(
freqs
)
# Translate from coding range scale to frequency table scale
total
=
freqs
.
get_total
()
if
total
>
ArithmeticCoderBase
.
MAX_TOTAL
:
raise
ValueError
(
"Cannot decode symbol because total is too large"
)
range
=
self
.
high
-
self
.
low
+
1
offset
=
self
.
code
-
self
.
low
value
=
((
offset
+
1
)
*
total
-
1
)
//
range
assert
value
*
range
//
total
<=
offset
assert
0
<=
value
<
total
# A kind of binary search. Find highest symbol such that freqs.get_low(symbol) <= value.
start
=
0
end
=
freqs
.
get_symbol_limit
()
while
end
-
start
>
1
:
middle
=
(
start
+
end
)
>>
1
if
freqs
.
get_low
(
middle
)
>
value
:
end
=
middle
else
:
start
=
middle
assert
start
+
1
==
end
symbol
=
start
assert
freqs
.
get_low
(
symbol
)
*
range
//
total
<=
offset
<
freqs
.
get_high
(
symbol
)
*
range
//
total
self
.
update
(
freqs
,
symbol
)
if
not
(
self
.
low
<=
self
.
code
<=
self
.
high
):
raise
AssertionError
(
"Code out of range"
)
return
symbol
def
shift
(
self
):
self
.
code
=
((
self
.
code
<<
1
)
&
ArithmeticCoderBase
.
MASK
)
|
self
.
read_code_bit
()
def
underflow
(
self
):
self
.
code
=
(
self
.
code
&
ArithmeticCoderBase
.
TOP_MASK
)
|
((
self
.
code
<<
1
)
&
(
ArithmeticCoderBase
.
MASK
>>
1
))
|
self
.
read_code_bit
()
# Returns the next bit (0 or 1) from the input stream. The end
# of stream is treated as an infinite number of trailing zeros.
def
read_code_bit
(
self
):
temp
=
self
.
input
.
read
()
if
temp
==
-
1
:
temp
=
0
return
temp
# ---- Frequency table classes ----
# A table of symbol frequencies. The table holds data for symbols numbered from 0
# to get_symbol_limit()-1. Each symbol has a frequency, which is a non-negative integer.
# Frequency table objects are primarily used for getting cumulative symbol
# frequencies. These objects can be mutable depending on the implementation.
class
FrequencyTable
(
object
):
# Returns the number of symbols in this frequency table, which is a positive number.
def
get_symbol_limit
(
self
):
raise
NotImplementedError
()
# Returns the frequency of the given symbol. The returned value is at least 0.
def
get
(
self
,
symbol
):
raise
NotImplementedError
()
# Sets the frequency of the given symbol to the given value.
# The frequency value must be at least 0.
def
set
(
self
,
symbol
,
freq
):
raise
NotImplementedError
()
# Increments the frequency of the given symbol.
def
increment
(
self
,
symbol
):
raise
NotImplementedError
()
# Returns the total of all symbol frequencies. The returned value is at
# least 0 and is always equal to get_high(get_symbol_limit() - 1).
def
get_total
(
self
):
raise
NotImplementedError
()
# Returns the sum of the frequencies of all the symbols strictly
# below the given symbol value. The returned value is at least 0.
def
get_low
(
self
,
symbol
):
raise
NotImplementedError
()
# Returns the sum of the frequencies of the given symbol
# and all the symbols below. The returned value is at least 0.
def
get_high
(
self
,
symbol
):
raise
NotImplementedError
()
# An immutable frequency table where every symbol has the same frequency of 1.
# Useful as a fallback model when no statistics are available.
class
FlatFrequencyTable
(
FrequencyTable
):
# Constructs a flat frequency table with the given number of symbols.
def
__init__
(
self
,
numsyms
):
if
numsyms
<
1
:
raise
ValueError
(
"Number of symbols must be positive"
)
self
.
numsymbols
=
numsyms
# Total number of symbols, which is at least 1
# Returns the number of symbols in this table, which is at least 1.
def
get_symbol_limit
(
self
):
return
self
.
numsymbols
# Returns the frequency of the given symbol, which is always 1.
def
get
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
return
1
# Returns the total of all symbol frequencies, which is
# always equal to the number of symbols in this table.
def
get_total
(
self
):
return
self
.
numsymbols
# Returns the sum of the frequencies of all the symbols strictly below
# the given symbol value. The returned value is equal to 'symbol'.
def
get_low
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
return
symbol
# Returns the sum of the frequencies of the given symbol and all
# the symbols below. The returned value is equal to 'symbol' + 1.
def
get_high
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
return
symbol
+
1
# Returns silently if 0 <= symbol < numsymbols, otherwise raises an exception.
def
_check_symbol
(
self
,
symbol
):
if
0
<=
symbol
<
self
.
numsymbols
:
return
else
:
raise
ValueError
(
"Symbol out of range"
)
# Returns a string representation of this frequency table. The format is subject to change.
def
__str__
(
self
):
return
"FlatFrequencyTable={}"
.
format
(
self
.
numsymbols
)
# Unsupported operation, because this frequency table is immutable.
def
set
(
self
,
symbol
,
freq
):
raise
NotImplementedError
()
# Unsupported operation, because this frequency table is immutable.
def
increment
(
self
,
symbol
):
raise
NotImplementedError
()
# A mutable table of symbol frequencies. The number of symbols cannot be changed
# after construction. The current algorithm for calculating cumulative frequencies
# takes linear time, but there exist faster algorithms such as Fenwick trees.
class
SimpleFrequencyTable
(
FrequencyTable
):
# Constructs a simple frequency table in one of two ways:
# - SimpleFrequencyTable(sequence):
# Builds a frequency table from the given sequence of symbol frequencies.
# There must be at least 1 symbol, and no symbol has a negative frequency.
# - SimpleFrequencyTable(freqtable):
# Builds a frequency table by copying the given frequency table.
def
__init__
(
self
,
freqs
):
if
isinstance
(
freqs
,
FrequencyTable
):
numsym
=
freqs
.
get_symbol_limit
()
self
.
frequencies
=
[
freqs
.
get
(
i
)
for
i
in
range
(
numsym
)]
else
:
# Assume it is a sequence type
self
.
frequencies
=
list
(
freqs
)
# Make copy
# 'frequencies' is a list of the frequency for each symbol.
# Its length is at least 1, and each element is non-negative.
if
len
(
self
.
frequencies
)
<
1
:
raise
ValueError
(
"At least 1 symbol needed"
)
for
freq
in
self
.
frequencies
:
if
freq
<
0
:
raise
ValueError
(
"Negative frequency"
)
# Always equal to the sum of 'frequencies'
self
.
total
=
sum
(
self
.
frequencies
)
# cumulative[i] is the sum of 'frequencies' from 0 (inclusive) to i (exclusive).
# Initialized lazily. When it is not None, the data is valid.
self
.
cumulative
=
None
# Returns the number of symbols in this frequency table, which is at least 1.
def
get_symbol_limit
(
self
):
return
len
(
self
.
frequencies
)
# Returns the frequency of the given symbol. The returned value is at least 0.
def
get
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
return
self
.
frequencies
[
symbol
]
# Sets the frequency of the given symbol to the given value. The frequency value
# must be at least 0. If an exception is raised, then the state is left unchanged.
def
set
(
self
,
symbol
,
freq
):
self
.
_check_symbol
(
symbol
)
if
freq
<
0
:
raise
ValueError
(
"Negative frequency"
)
temp
=
self
.
total
-
self
.
frequencies
[
symbol
]
assert
temp
>=
0
self
.
total
=
temp
+
freq
self
.
frequencies
[
symbol
]
=
freq
self
.
cumulative
=
None
# Increments the frequency of the given symbol.
def
increment
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
self
.
total
+=
1
self
.
frequencies
[
symbol
]
+=
1
self
.
cumulative
=
None
# Returns the total of all symbol frequencies. The returned value is at
# least 0 and is always equal to get_high(get_symbol_limit() - 1).
def
get_total
(
self
):
return
self
.
total
# Returns the sum of the frequencies of all the symbols strictly
# below the given symbol value. The returned value is at least 0.
def
get_low
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
if
self
.
cumulative
is
None
:
self
.
_init_cumulative
()
return
self
.
cumulative
[
symbol
]
# Returns the sum of the frequencies of the given symbol
# and all the symbols below. The returned value is at least 0.
def
get_high
(
self
,
symbol
):
self
.
_check_symbol
(
symbol
)
if
self
.
cumulative
is
None
:
self
.
_init_cumulative
()
return
self
.
cumulative
[
symbol
+
1
]
# Recomputes the array of cumulative symbol frequencies.
def
_init_cumulative
(
self
):
cumul
=
[
0
]
sum
=
0
for
freq
in
self
.
frequencies
:
sum
+=
freq
cumul
.
append
(
sum
)
assert
sum
==
self
.
total
self
.
cumulative
=
cumul
# Returns silently if 0 <= symbol < len(frequencies), otherwise raises an exception.
def
_check_symbol
(
self
,
symbol
):
if
0
<=
symbol
<
len
(
self
.
frequencies
):
return
else
:
raise
ValueError
(
"Symbol out of range"
)
# Returns a string representation of this frequency table,
# useful for debugging only, and the format is subject to change.
def
__str__
(
self
):
result
=
""
for
(
i
,
freq
)
in
enumerate
(
self
.
frequencies
):
result
+=
"{}
\t
{}
\n
"
.
format
(
i
,
freq
)
return
result
# A wrapper that checks the preconditions (arguments) and postconditions (return value) of all
# the frequency table methods. Useful for finding faults in a frequency table implementation.
class
CheckedFrequencyTable
(
FrequencyTable
):
def
__init__
(
self
,
freqtab
):
# The underlying frequency table that holds the data
self
.
freqtable
=
freqtab
def
get_symbol_limit
(
self
):
result
=
self
.
freqtable
.
get_symbol_limit
()
if
result
<=
0
:
raise
AssertionError
(
"Non-positive symbol limit"
)
return
result
def
get
(
self
,
symbol
):
result
=
self
.
freqtable
.
get
(
symbol
)
if
not
self
.
_is_symbol_in_range
(
symbol
):
raise
AssertionError
(
"ValueError expected"
)
if
result
<
0
:
raise
AssertionError
(
"Negative symbol frequency"
)
return
result
def
get_total
(
self
):
result
=
self
.
freqtable
.
get_total
()
if
result
<
0
:
raise
AssertionError
(
"Negative total frequency"
)
return
result
def
get_low
(
self
,
symbol
):
if
self
.
_is_symbol_in_range
(
symbol
):
low
=
self
.
freqtable
.
get_low
(
symbol
)
high
=
self
.
freqtable
.
get_high
(
symbol
)
if
not
(
0
<=
low
<=
high
<=
self
.
freqtable
.
get_total
()):
raise
AssertionError
(
"Symbol low cumulative frequency out of range"
)
return
low
else
:
self
.
freqtable
.
get_low
(
symbol
)
raise
AssertionError
(
"ValueError expected"
)
def
get_high
(
self
,
symbol
):
if
self
.
_is_symbol_in_range
(
symbol
):
low
=
self
.
freqtable
.
get_low
(
symbol
)
high
=
self
.
freqtable
.
get_high
(
symbol
)
if
not
(
0
<=
low
<=
high
<=
self
.
freqtable
.
get_total
()):
raise
AssertionError
(
"Symbol high cumulative frequency out of range"
)
return
high
else
:
self
.
freqtable
.
get_high
(
symbol
)
raise
AssertionError
(
"ValueError expected"
)
def
__str__
(
self
):
return
"CheckFrequencyTable ("
+
str
(
self
.
freqtable
)
+
")"
def
set
(
self
,
symbol
,
freq
):
self
.
freqtable
.
set
(
symbol
,
freq
)
if
not
self
.
_is_symbol_in_range
(
symbol
)
or
freq
<
0
:
raise
AssertionError
(
"ValueError expected"
)
def
increment
(
self
,
symbol
):
self
.
freqtable
.
increment
(
symbol
)
if
not
self
.
_is_symbol_in_range
(
symbol
):
raise
AssertionError
(
"ValueError expected"
)
def
_is_symbol_in_range
(
self
,
symbol
):
return
0
<=
symbol
<
self
.
get_symbol_limit
()
# ---- Bit-oriented I/O streams ----
# A stream of bits that can be read. Because they come from an underlying byte stream,
# the total number of bits is always a multiple of 8. The bits are read in big endian.
class
BitInputStream
(
object
):
# Constructs a bit input stream based on the given byte input stream.
def
__init__
(
self
,
inp
):
# The underlying byte stream to read from
self
.
input
=
inp
# Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached
self
.
currentbyte
=
0
# Number of remaining bits in the current byte, always between 0 and 7 (inclusive)
self
.
numbitsremaining
=
0
# Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if
# the end of stream is reached. The end of stream always occurs on a byte boundary.
def
read
(
self
):
if
self
.
currentbyte
==
-
1
:
return
-
1
if
self
.
numbitsremaining
==
0
:
temp
=
self
.
input
.
read
(
1
)
if
len
(
temp
)
==
0
:
self
.
currentbyte
=
-
1
return
-
1
self
.
currentbyte
=
temp
[
0
]
if
python3
else
ord
(
temp
)
self
.
numbitsremaining
=
8
assert
self
.
numbitsremaining
>
0
self
.
numbitsremaining
-=
1
return
(
self
.
currentbyte
>>
self
.
numbitsremaining
)
&
1
# Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError
# if the end of stream is reached. The end of stream always occurs on a byte boundary.
def
read_no_eof
(
self
):
result
=
self
.
read
()
if
result
!=
-
1
:
return
result
else
:
raise
EOFError
()
# Closes this stream and the underlying input stream.
def
close
(
self
):
self
.
input
.
close
()
self
.
currentbyte
=
-
1
self
.
numbitsremaining
=
0
# A stream where bits can be written to. Because they are written to an underlying
# byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits.
# The bits are written in big endian.
class
BitOutputStream
(
object
):
# Constructs a bit output stream based on the given byte output stream.
def
__init__
(
self
,
out
):
self
.
output
=
out
# The underlying byte stream to write to
self
.
currentbyte
=
0
# The accumulated bits for the current byte, always in the range [0x00, 0xFF]
self
.
numbitsfilled
=
0
# Number of accumulated bits in the current byte, always between 0 and 7 (inclusive)
# Writes a bit to the stream. The given bit must be 0 or 1.
def
write
(
self
,
b
):
if
b
not
in
(
0
,
1
):
raise
ValueError
(
"Argument must be 0 or 1"
)
self
.
currentbyte
=
(
self
.
currentbyte
<<
1
)
|
b
self
.
numbitsfilled
+=
1
if
self
.
numbitsfilled
==
8
:
towrite
=
bytes
((
self
.
currentbyte
,))
if
python3
else
chr
(
self
.
currentbyte
)
self
.
output
.
write
(
towrite
)
self
.
currentbyte
=
0
self
.
numbitsfilled
=
0
# Closes this stream and the underlying output stream. If called when this
# bit stream is not at a byte boundary, then the minimum number of "0" bits
# (between 0 and 7 of them) are written as padding to reach the next byte boundary.
def
close
(
self
):
while
self
.
numbitsfilled
!=
0
:
self
.
write
(
0
)
self
.
output
.
close
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment