Skip to content
Snippets Groups Projects
Verified Commit 6b7c96bd authored by Tamas Gal's avatar Tamas Gal :speech_balloon:
Browse files

Add logger

parent b2b0b6b2
No related branches found
No related tags found
No related merge requests found
from controlhost import Client from controlhost import Client
with Client('127.0.0.1') as client: with Client("127.0.0.1") as client:
client.subscribe('foo') client.subscribe("foo")
try: try:
while True: while True:
prefix, message = client.get_message() prefix, message = client.get_message()
print prefix.tag print(prefix.tag)
print prefix.length print(prefix.length)
print message print(message)
except KeyboardInterrupt: except KeyboardInterrupt:
client._disconnect() client._disconnect()
...@@ -3,18 +3,17 @@ import socket ...@@ -3,18 +3,17 @@ import socket
import controlhost as ch import controlhost as ch
s = socket.socket() s = socket.socket()
host = '131.188.167.62' host = "131.188.167.62"
port = 5553 port = 5553
s.connect((host, port)) s.connect((host, port))
message = ch.Message('_Subscri', ' w foo') message = ch.Message("_Subscri", " w foo")
print(s.send(message.data)) print(s.send(message.data))
message = ch.Message('_Always') message = ch.Message("_Always")
print(s.send(message.data)) print(s.send(message.data))
print(s.recv(1024)) print(s.recv(1024))
s.close() s.close()
...@@ -3,13 +3,12 @@ import socket ...@@ -3,13 +3,12 @@ import socket
import controlhost as ch import controlhost as ch
s = socket.socket() s = socket.socket()
host = '131.188.167.62' host = "131.188.167.62"
port = 5553 port = 5553
s.connect((host, port)) s.connect((host, port))
message = ch.Message('foo', 'test') message = ch.Message("foo", "test")
print(s.send(message.data)) print(s.send(message.data))
s.close() s.close()
...@@ -8,8 +8,9 @@ from .__version__ import version ...@@ -8,8 +8,9 @@ from .__version__ import version
from .controlhost import Client, Tag, Message, Prefix from .controlhost import Client, Tag, Message, Prefix
__author__ = "Tamas Gal" __author__ = "Tamas Gal"
__copyright__ = ("Copyright 2014, Tamas Gal and the KM3NeT collaboration " __copyright__ = (
"(http://km3net.org)") "Copyright 2014, Tamas Gal and the KM3NeT collaboration " "(http://km3net.org)"
)
__credits__ = [] __credits__ = []
__license__ = "MIT" __license__ = "MIT"
__version__ = version __version__ = version
......
...@@ -10,23 +10,25 @@ Pep 386 compliant version info. ...@@ -10,23 +10,25 @@ Pep 386 compliant version info.
(1, 2, 0, 'beta', 2) => "1.2b2" (1, 2, 0, 'beta', 2) => "1.2b2"
""" """
version_info = (0, 7, 1, 'final', 0) version_info = (0, 7, 1, "final", 0)
def _get_version(version_info): def _get_version(version_info):
"""Return a PEP 386-compliant version number.""" """Return a PEP 386-compliant version number."""
assert len(version_info) == 5 assert len(version_info) == 5
assert version_info[3] in ('alpha', 'beta', 'rc', 'final') assert version_info[3] in ("alpha", "beta", "rc", "final")
parts = 2 if version_info[2] == 0 else 3 parts = 2 if version_info[2] == 0 else 3
main = '.'.join(map(str, version_info[:parts])) main = ".".join(map(str, version_info[:parts]))
sub = '' sub = ""
if version_info[3] == 'alpha' and version_info[4] == 0: if version_info[3] == "alpha" and version_info[4] == 0:
sub = '.dev' sub = ".dev"
elif version_info[3] != 'final': elif version_info[3] != "final":
mapping = {'alpha': 'a', 'beta': 'b', 'rc': 'c'} mapping = {"alpha": "a", "beta": "b", "rc": "c"}
sub = mapping[version_info[3]] + str(version_info[4]) sub = mapping[version_info[3]] + str(version_info[4])
return str(main + sub) return str(main + sub)
version = _get_version(version_info) version = _get_version(version_info)
...@@ -3,17 +3,11 @@ ...@@ -3,17 +3,11 @@
A set of classes and tools wich uses the ControlHost protocol. A set of classes and tools wich uses the ControlHost protocol.
""" """
from __future__ import absolute_import, print_function, division
from collections import namedtuple
import socket import socket
import struct import struct
import time import time
try: from .logger import get_logger
from km3pipe.logging import get_logger
except ImportError:
from logging import getLogger as get_logger
__author__ = "Tamas Gal" __author__ = "Tamas Gal"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration." __copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
...@@ -27,27 +21,27 @@ log = get_logger(__name__) ...@@ -27,27 +21,27 @@ log = get_logger(__name__)
BUFFER_SIZE = 1024 BUFFER_SIZE = 1024
CHResponse = namedtuple("CHResponse", ["prefix", "data"])
class Client(object): class Client(object):
"""The ControlHost client""" """The ControlHost client"""
def __init__(self, host, port=5553, tag=None): _valid_modes = ["any", "all"]
def __init__(self, host, port=5553):
self.host = host self.host = host
self.port = port self.port = port
self.socket = None self.socket = None
self.tags = [] self.tags = []
self.valid_tags = [] self.valid_tags = []
if tag is not None: def subscribe(self, tag, mode="any"):
self._connect() if mode not in self._valid_modes:
self.subscribe(tag) raise ValueError(
"Possible subscription modes are: {}".format(
def subscribe(self, tag, mode='wait'): ", ".join(self._valid_modes)
if mode not in ['wait', 'all']: )
raise ValueError("Possible subscription modes are 'wait' or 'all'") )
log.info("Subscribing to {} in mode {}".format(tag, mode)) log.info("Subscribing to %s in mode %s", tag, mode)
full_tag = self._full_tag(tag, mode) full_tag = self._full_tag(tag, mode)
if full_tag not in self.tags: if full_tag not in self.tags:
self.tags.append(full_tag) self.tags.append(full_tag)
...@@ -56,7 +50,7 @@ class Client(object): ...@@ -56,7 +50,7 @@ class Client(object):
self.valid_tags.append(t) self.valid_tags.append(t)
self._update_subscriptions() self._update_subscriptions()
def unsubscribe(self, tag, mode='wait'): def unsubscribe(self, tag, mode="any"):
try: try:
self.tags.remove(self._full_tag(tag, mode)) self.tags.remove(self._full_tag(tag, mode))
self.valid_tags.remove(tag) self.valid_tags.remove(tag)
...@@ -66,26 +60,39 @@ class Client(object): ...@@ -66,26 +60,39 @@ class Client(object):
self._update_subscriptions() self._update_subscriptions()
def _full_tag(self, tag, mode): def _full_tag(self, tag, mode):
mode_flag = ' {} '.format(mode[0]) mode_flag = " {} ".format("w" if mode == "any" else "a")
full_tag = mode_flag + tag full_tag = mode_flag + tag
return full_tag return full_tag
def _update_subscriptions(self): def _update_subscriptions(self):
log.debug("Subscribing to tags: {0}".format(self.tags)) log.debug("Subscribing to tags: %s", self.tags)
tags = ''.join(self.tags).encode("ascii") if not self.socket:
message = Message(b'_Subscri', tags) self._connect()
tags = "".join(self.tags).encode("ascii")
message = Message(b"_Subscri", tags)
self.socket.send(message.data) self.socket.send(message.data)
message = Message(b'_Always') message = Message(b"_Always")
self.socket.send(message.data) self.socket.send(message.data)
def put_message(self, tag, data):
"""Send data to the ligier with a given tag"""
if not self.socket:
self._connect()
msg = Message(tag, data)
self.socket.send(msg.data)
def get_message(self): def get_message(self):
while True: while True:
log.info(" Waiting for control host Prefix") log.info(" Waiting for control host Prefix")
if self.socket is None:
log.error("Lost socket connection, reconnecting...")
self._reconnect()
continue
try: try:
data = self.socket.recv(Prefix.SIZE) data = self._recv(Prefix.SIZE)
timestamp = time.time() timestamp = time.time()
log.info(" raw prefix data received: '{0}'".format(data)) log.info(" raw prefix data received: '%s'", data)
if data == b'': if data == b"":
raise EOFError raise EOFError
prefix = Prefix(data=data, timestamp=timestamp) prefix = Prefix(data=data, timestamp=timestamp)
except (UnicodeDecodeError, OSError, struct.error): except (UnicodeDecodeError, OSError, struct.error):
...@@ -102,40 +109,30 @@ class Client(object): ...@@ -102,40 +109,30 @@ class Client(object):
if prefix_tag not in self.valid_tags: if prefix_tag not in self.valid_tags:
log.error( log.error(
"Invalid tag '{0}' received, ignoring the message \n" "Invalid tag '%s' received, ignoring the message \n"
"and reconnecting.\n" "and reconnecting.\n"
" -> valid tags are: {1}".format( " -> valid tags are: %s",
prefix_tag, self.valid_tags prefix_tag,
) self.valid_tags,
) )
self._reconnect() self._reconnect()
continue continue
else: else:
break break
message = b'' log.info(" got a Prefix with %d bytes.", prefix.length)
log.info(" got a Prefix with {0} bytes.".format(prefix.length)) message = self._recv(prefix.length)
while len(message) < prefix.length: log.info(" ------ returning message with %d bytes", len(message))
log.info(" message length: {0}".format(len(message))) return prefix, message
log.info(" (getting next part)")
buffer_size = min((BUFFER_SIZE, (prefix.length - len(message))))
try:
message += self.socket.recv(buffer_size)
except OSError:
log.error("Failed to construct message.")
raise BufferError
log.info(
" ------ returning message with {0} bytes".format(
len(message)
)
)
return CHResponse(prefix, message)
def _connect(self): def _connect(self):
"""Connect to JLigier""" """Connect to JLigier"""
log.debug("Connecting to JLigier") log.debug("Connecting to JLigier")
self.socket = socket.socket()
self.socket.connect((self.host, self.port)) s = socket.socket()
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
s.connect((self.host, self.port))
self.socket = s
def _disconnect(self): def _disconnect(self):
"""Close the socket""" """Close the socket"""
...@@ -150,9 +147,25 @@ class Client(object): ...@@ -150,9 +147,25 @@ class Client(object):
self._connect() self._connect()
self._update_subscriptions() self._update_subscriptions()
def _recv(self, size):
"""Receive the exact amount of bytes from the socket.
This is needed since socket.recv(size) may return less than
size, according to the specification: 'size' is the maximum
number of bytes returned.
"""
message = b""
while len(message) < size:
buffer_size = min((BUFFER_SIZE, (size - len(message))))
try:
message += self.socket.recv(buffer_size)
except OSError as e:
log.error("Failed to receive controlhost message: %s", e)
raise BufferError
return message
def __enter__(self): def __enter__(self):
if not self.socket: self._connect()
self._connect()
return self return self
def __exit__(self, exit_type, value, traceback): def __exit__(self, exit_type, value, traceback):
...@@ -162,7 +175,15 @@ class Client(object): ...@@ -162,7 +175,15 @@ class Client(object):
class Message(object): class Message(object):
"""The representation of a ControlHost message.""" """The representation of a ControlHost message."""
def __init__(self, tag, message=b''): def __init__(self, tag, message=b""):
try:
message = message.encode()
except AttributeError:
pass
try:
tag = tag.encode()
except AttributeError:
pass
self.prefix = Prefix(tag, len(message)) self.prefix = Prefix(tag, len(message))
self.message = message self.message = message
...@@ -173,10 +194,11 @@ class Message(object): ...@@ -173,10 +194,11 @@ class Message(object):
class Tag(object): class Tag(object):
"""Represents the tag in a ControlHost Prefix.""" """Represents the tag in a ControlHost Prefix."""
SIZE = 8 SIZE = 8
def __init__(self, data=None): def __init__(self, data=None):
self._data = b'' self._data = b""
self.data = data self.data = data
@property @property
...@@ -188,15 +210,15 @@ class Tag(object): ...@@ -188,15 +210,15 @@ class Tag(object):
def data(self, value): def data(self, value):
"""Set the byte data and fill up the bytes to fit the size.""" """Set the byte data and fill up the bytes to fit the size."""
if not value: if not value:
value = b'' value = b""
if len(value) > self.SIZE: if len(value) > self.SIZE:
raise ValueError("The maximum tag size is {0}".format(self.SIZE)) raise ValueError("The maximum tag size is {}".format(self.SIZE))
self._data = value self._data = value
while len(self._data) < self.SIZE: while len(self._data) < self.SIZE:
self._data += b'\x00' self._data += b"\x00"
def __str__(self): def __str__(self):
return self.data.decode(encoding='UTF-8').strip('\x00') return self.data.decode(encoding="UTF-8").strip("\x00")
def __len__(self): def __len__(self):
return len(self._data) return len(self._data)
...@@ -204,6 +226,7 @@ class Tag(object): ...@@ -204,6 +226,7 @@ class Tag(object):
class Prefix(object): class Prefix(object):
"""The prefix of a ControlHost message.""" """The prefix of a ControlHost message."""
SIZE = 16 SIZE = 16
def __init__(self, tag=None, length=None, data=None, timestamp=None): def __init__(self, tag=None, length=None, data=None, timestamp=None):
...@@ -219,16 +242,14 @@ class Prefix(object): ...@@ -219,16 +242,14 @@ class Prefix(object):
@property @property
def data(self): def data(self):
return self.tag.data + struct.pack('>i', self.length) + b'\x00' * 4 return self.tag.data + struct.pack(">i", self.length) + b"\x00" * 4
@data.setter @data.setter
def data(self, value): def data(self, value):
self.tag = Tag(data=value[:Tag.SIZE]) self.tag = Tag(data=value[: Tag.SIZE])
self.length = struct.unpack('>i', value[Tag.SIZE:Tag.SIZE + 4])[0] self.length = struct.unpack(">i", value[Tag.SIZE : Tag.SIZE + 4])[0]
def __str__(self): def __str__(self):
return ( return "ControlHost Prefix with tag '{0}' ({1} bytes of data)".format(
"ControlHost Prefix with tag '{0}' ({1} bytes of data)".format( self.tag, self.length
self.tag, self.length
)
) )
# Filename: logger.py
# pylint: disable=locally-disabled,C0103
"""
The logging facility.
"""
from hashlib import sha256
from inspect import getframeinfo, stack
import socket
import sys
import logging
import logging.handlers
__author__ = "Tamas Gal"
__copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = []
__license__ = "MIT"
__maintainer__ = "Tamas Gal"
__email__ = "tgal@km3net.de"
__status__ = "Development"
loggers = {} # this holds all the registered loggers
# logging.basicConfig()
DEPRECATION = 45
logging.addLevelName(DEPRECATION, "DEPRECATION")
ONCE = 46
logging.addLevelName(ONCE, "ONCE")
def supports_color():
"""Checks if the terminal supports color."""
if isnotebook():
return True
supported_platform = sys.platform != "win32" or "ANSICON" in os.environ
is_a_tty = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
if not supported_platform or not is_a_tty:
return False
return True
def isnotebook():
"""Check if running within a Jupyter notebook"""
try:
shell = get_ipython().__class__.__name__
if shell == "ZMQInteractiveShell":
return True # Jupyter notebook or qtconsole
elif shell == "TerminalInteractiveShell":
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False
ATTRIBUTES = dict(
list(
zip(
["bold", "dark", "", "underline", "blink", "", "reverse", "concealed"],
list(range(1, 9)),
)
)
)
del ATTRIBUTES[""]
ATTRIBUTES_RE = r"\033\[(?:%s)m" % "|".join(["%d" % v for v in ATTRIBUTES.values()])
HIGHLIGHTS = dict(
list(
zip(
[
"on_grey",
"on_red",
"on_green",
"on_yellow",
"on_blue",
"on_magenta",
"on_cyan",
"on_white",
],
list(range(40, 48)),
)
)
)
HIGHLIGHTS_RE = r"\033\[(?:%s)m" % "|".join(["%d" % v for v in HIGHLIGHTS.values()])
COLORS = dict(
list(
zip(
[
"grey",
"red",
"green",
"yellow",
"blue",
"magenta",
"cyan",
"white",
],
list(range(30, 38)),
)
)
)
COLORS_RE = r"\033\[(?:%s)m" % "|".join(["%d" % v for v in COLORS.values()])
RESET = r"\033[0m"
RESET_RE = r"\033\[0m"
def colored(text, color=None, on_color=None, attrs=None, ansi_code=None):
"""Colorize text, while stripping nested ANSI color sequences.
Author: Konstantin Lepa <konstantin.lepa@gmail.com> / termcolor
Available text colors:
red, green, yellow, blue, magenta, cyan, white.
Available text highlights:
on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white.
Available attributes:
bold, dark, underline, blink, reverse, concealed.
Example:
colored('Hello, World!', 'red', 'on_grey', ['blue', 'blink'])
colored('Hello, World!', 'green')
"""
if os.getenv("ANSI_COLORS_DISABLED") is None:
if ansi_code is not None:
return "\033[38;5;{}m{}\033[0m".format(ansi_code, text)
fmt_str = "\033[%dm%s"
if color is not None:
text = re.sub(COLORS_RE + "(.*?)" + RESET_RE, r"\1", text)
text = fmt_str % (COLORS[color], text)
if on_color is not None:
text = re.sub(HIGHLIGHTS_RE + "(.*?)" + RESET_RE, r"\1", text)
text = fmt_str % (HIGHLIGHTS[on_color], text)
if attrs is not None:
text = re.sub(ATTRIBUTES_RE + "(.*?)" + RESET_RE, r"\1", text)
for attr in attrs:
text = fmt_str % (ATTRIBUTES[attr], text)
return text + RESET
else:
return text
def deprecation(self, message, *args, **kws):
"""Show a deprecation warning."""
self._log(DEPRECATION, message, args, **kws)
def once(self, message, *args, **kws):
"""Show a message only once, determined by position in source or identifer.
This will not work in IPython or Jupyter notebooks if no identifier is
specified, since then the determined position in source contains the
execution number of the input (cell), which changes every time.
Set a unique ``identifier=X``, otherwise the message will be printed every
time.
"""
identifier = kws.pop("identifier", None)
if identifier is None:
caller = getframeinfo(stack()[1][0])
identifier = "%s:%d" % (caller.filename, caller.lineno)
if not hasattr(self, "once_dict"):
self.once_dict = {}
if identifier in self.once_dict:
return
self.once_dict[identifier] = True
self._log(ONCE, message, args, **kws)
logging.Logger.deprecation = deprecation
logging.Logger.once = once
if supports_color():
logging.addLevelName(
logging.INFO, "\033[1;32m%s\033[1;0m" % logging.getLevelName(logging.INFO)
)
logging.addLevelName(
logging.DEBUG, "\033[1;34m%s\033[1;0m" % logging.getLevelName(logging.DEBUG)
)
logging.addLevelName(
logging.WARNING, "\033[1;33m%s\033[1;0m" % logging.getLevelName(logging.WARNING)
)
logging.addLevelName(
logging.ERROR, "\033[1;31m%s\033[1;0m" % logging.getLevelName(logging.ERROR)
)
logging.addLevelName(
logging.CRITICAL,
"\033[1;101m%s\033[1;0m" % logging.getLevelName(logging.CRITICAL),
)
logging.addLevelName(DEPRECATION, "\033[1;35m%s\033[1;0m" % "DEPRECATION")
logging.addLevelName(ONCE, "\033[1;36m%s\033[1;0m" % "ONCE")
class LogIO(object):
"""Read/write logging information."""
def __init__(self, node, stream, url="pi2089.physik.uni-erlangen.de", port=28777):
self.node = node
self.stream = stream
self.url = url
self.port = port
self.sock = None
self.connect()
def send(self, message, level="info"):
message_string = "+log|{0}|{1}|{2}|{3}\r\n".format(
self.stream, self.node, level, message
)
try:
self.sock.send(message_string)
except socket.error:
print("Lost connection, reconnecting...")
self.connect()
self.sock.send(message_string)
def connect(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.connect((self.url, self.port))
def get_logger(name, filename=None, stream_loglevel="INFO", file_loglevel="DEBUG"):
"""Helper function to get a logger"""
if name in loggers:
return loggers[name]
logger = logging.getLogger(name)
logger.propagate = False
with_color = supports_color()
pre1, suf1 = hash_coloured_escapes(name) if with_color else ("", "")
pre2, suf2 = hash_coloured_escapes(name + "salt") if with_color else ("", "")
formatter = logging.Formatter(
"%(asctime)s %(levelname)s {}+{}+{} "
"%(name)s: %(message)s".format(pre1, pre2, suf1),
datefmt="%Y-%m-%d %H:%M:%S",
)
if filename is not None:
ch_file = logging.handlers.RotatingFileHandler(
filename, maxBytes=5 * 1024 * 1024, backupCount=10
)
ch_file.setLevel(file_loglevel)
ch_file.setFormatter(formatter)
logger.addHandler(ch_file)
ch = logging.StreamHandler()
ch.setLevel(stream_loglevel)
ch.setFormatter(formatter)
logger.addHandler(ch)
loggers[name] = logger
logger.once_dict = {}
return logger
def available_loggers():
"""Return a list of avialable logger names"""
return list(logging.Logger.manager.loggerDict.keys())
def set_level(log_or_name, level):
"""Set the log level for given logger and all handlers"""
if isinstance(log_or_name, str):
log = get_logger(log_or_name)
else:
log = log_or_name
log.setLevel(level)
for handler in log.handlers:
handler.setLevel(level)
def get_printer(name, color=None, ansi_code=None, force_color=False):
"""Return a function which prints a message with a coloured name prefix"""
if force_color or supports_color():
if color is None and ansi_code is None:
cpre_1, csuf_1 = hash_coloured_escapes(name)
cpre_2, csuf_2 = hash_coloured_escapes(name + "salt")
name = cpre_1 + "+" + cpre_2 + "+" + csuf_1 + " " + name
else:
name = colored(name, color=color, ansi_code=ansi_code)
prefix = name + ": "
def printer(text):
print(prefix + str(text))
return printer
def hash_coloured(text):
"""Return a ANSI coloured text based on its hash"""
ansi_code = int(sha256(text.encode("utf-8")).hexdigest(), 16) % 230
return colored(text, ansi_code=ansi_code)
def hash_coloured_escapes(text):
"""Return the ANSI hash colour prefix and suffix for a given text"""
ansi_code = int(sha256(text.encode("utf-8")).hexdigest(), 16) % 230
prefix, suffix = colored("SPLIT", ansi_code=ansi_code).split("SPLIT")
return prefix, suffix
...@@ -5,4 +5,4 @@ Unit tests for the controlhost package. ...@@ -5,4 +5,4 @@ Unit tests for the controlhost package.
""" """
__author__ = 'tamasgal' __author__ = "tamasgal"
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
Unit tests for the controlhost module. Unit tests for the controlhost module.
""" """
import unittest
from controlhost import Tag, Message, Prefix from controlhost import Tag, Message, Prefix
from unittest import TestCase
__author__ = "Tamas Gal" __author__ = "Tamas Gal"
__copyright__ = "Copyright 2018, Tamas Gal and the KM3NeT collaboration." __copyright__ = "Copyright 2016, Tamas Gal and the KM3NeT collaboration."
__credits__ = [] __credits__ = []
__license__ = "MIT" __license__ = "MIT"
__maintainer__ = "Tamas Gal" __maintainer__ = "Tamas Gal"
...@@ -17,35 +18,35 @@ __email__ = "tgal@km3net.de" ...@@ -17,35 +18,35 @@ __email__ = "tgal@km3net.de"
__status__ = "Development" __status__ = "Development"
class TestTag(unittest.TestCase): class TestTag(TestCase):
def test_empty_tag_has_correct_length(self): def test_empty_tag_has_correct_length(self):
tag = Tag() tag = Tag()
self.assertEqual(Tag.SIZE, len(tag)) self.assertEqual(Tag.SIZE, len(tag))
def test_tag_has_correct_length(self): def test_tag_has_correct_length(self):
for tag_name in (b'foo', b'bar', b'baz', b'1'): for tag_name in (b"foo", b"bar", b"baz", b"1"):
tag = Tag(tag_name) tag = Tag(tag_name)
self.assertEqual(Tag.SIZE, len(tag)) self.assertEqual(Tag.SIZE, len(tag))
def test_tag_with_invalid_length_raises_valueerror(self): def test_tag_with_invalid_length_raises_valueerror(self):
self.assertRaises(ValueError, Tag, '123456789') self.assertRaises(ValueError, Tag, "123456789")
def test_tag_has_correct_data(self): def test_tag_has_correct_data(self):
tag = Tag(b'foo') tag = Tag(b"foo")
self.assertEqual(b'foo\x00\x00\x00\x00\x00', tag.data) self.assertEqual(b"foo\x00\x00\x00\x00\x00", tag.data)
tag = Tag('abcdefgh') tag = Tag("abcdefgh")
self.assertEqual('abcdefgh', tag.data) self.assertEqual("abcdefgh", tag.data)
def test_tag_has_correct_string_representation(self): def test_tag_has_correct_string_representation(self):
tag = Tag(b'foo') tag = Tag(b"foo")
self.assertEqual('foo', str(tag)) self.assertEqual("foo", str(tag))
class TestPrefix(unittest.TestCase): class TestPrefix(TestCase):
def test_init(self): def test_init(self):
Prefix(b'foo', 1) Prefix(b"foo", 1)
class TestMessage(unittest.TestCase): class TestMessage(TestCase):
def test_init(self): def test_init(self):
Message('') Message("")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment