Commit 80a5bd8f authored by Chris Snijder's avatar Chris Snijder 🏅
Browse files

Make HAProxy socket timeout configurable, move some code around for readability.

parent 24c91067
......@@ -33,11 +33,13 @@ import configargparse
import logging
import logging.handlers
import os
import sys
import daemon
import stapled
import stapled.core.daemon
import stapled.core.excepthandler
from stapled.core.excepthandler import handle_file_error
from stapled.core.exceptions import ArgumentError
from stapled.util.haproxy import parse_haproxy_config
from stapled.colourlog import ColourFormatter
from stapled.version import __version__, __app_name__
......@@ -222,6 +224,19 @@ def get_cli_arg_parser():
"specified in the config file."
)
)
parser.add(
'--haproxy-socket-timeout',
type=int,
default=3600,
metavar="TIMEOUT <seconds, minimum: 10>",
help=(
"HAProxy sockets are kept open for performance reasons, you can "
"set the amount of seconds sockets should remain open "
"(default=3600). Note that a short amount of time is required to "
"to pass messages to HAProxy, so 10 seconds if the minimum "
"accepted value."
)
)
parser.add(
'--haproxy-config',
type=str,
......@@ -311,39 +326,14 @@ def init():
:func:`stapled.core.daemon.run()` either in daemonised mode if the ``-d``
argument was supplied, or in the current context if ``-d`` wasn't supplied.
"""
parser = get_cli_arg_parser()
args = parser.parse_args()
log_file_handles = __init_logging(args)
args = __get_validated_args()
# Parse the cert_paths argument
arg_cert_paths = __get_arg_cert_paths(args)
# Parse haproxy_sockets argument.
arg_haproxy_sockets = __get_arg_haproxy_sockets(args)
# Make a mapping from certificate paths to sockets in a dict.
haproxy_socket_mapping = dict(zip(arg_cert_paths, arg_haproxy_sockets))
# Parse HAProxy config files.
try:
conf_cert_paths, conf_haproxy_sockets = parse_haproxy_config(
args.haproxy_config
)
except (OSError, IOError) as exc:
logger.critical(handle_file_error(exc))
exit(1)
log_file_handles = __init_logging(args)
# Combine the socket and certificate paths of the arguments and config
# files in the sockets dictionary.
for i, paths in enumerate(conf_cert_paths):
for path in paths:
if path in haproxy_socket_mapping:
haproxy_socket_mapping[path] = unique(
haproxy_socket_mapping[path] + conf_haproxy_sockets[i],
preserve_order=False
)
else:
haproxy_socket_mapping[path] = conf_haproxy_sockets[i]
logger.debug("Paths to socket mapping: %s", str(haproxy_socket_mapping))
# Get a mapping of configured sockets and certificate directories from:
# haproxy config, stapled config and command line arguments
haproxy_socket_mapping = __get_haproxy_socket_mapping(args)
# Now sockets' keys are the merged cert paths from arguments and haproxy
# config files, de-duplicated.
......@@ -356,6 +346,7 @@ def init():
daemon_kwargs = dict(
cert_paths=cert_paths,
haproxy_socket_mapping=haproxy_socket_mapping,
haproxy_socket_timeout=args.haproxy_socket_timeout,
file_extensions=args.file_extensions,
renewal_threads=args.renewal_threads,
refresh_interval=args.refresh_interval,
......@@ -460,6 +451,67 @@ def __init_logging(args):
logger.addHandler(syslog_handler)
return log_file_handles
def __get_haproxy_socket_mapping(args):
"""
Get a mapping of configured sockets and certificate directories from:
haproxy config, stapled config and command line arguments.
:param Namespace args: Argparser argument list.
:return dict Of cert-paths and sockets for inform of changes.
"""
# Parse the cert_paths argument
arg_cert_paths = __get_arg_cert_paths(args)
# Parse haproxy_sockets argument.
arg_haproxy_sockets = __get_arg_haproxy_sockets(args)
# Make a mapping from certificate paths to sockets in a dict.
mapping = dict(zip(arg_cert_paths, arg_haproxy_sockets))
# Parse HAProxy config files.
try:
conf_cert_paths, conf_haproxy_sockets = parse_haproxy_config(
args.haproxy_config
)
except (OSError, IOError) as exc:
logger.critical(handle_file_error(exc))
exit(1)
# Combine the socket and certificate paths of the arguments and config
# files in the sockets dictionary.
for i, paths in enumerate(conf_cert_paths):
for path in paths:
if path in mapping:
mapping[path] = unique(
mapping[path] + conf_haproxy_sockets[i],
preserve_order=False
)
else:
mapping[path] = conf_haproxy_sockets[i]
logger.debug("Paths to socket mapping: %s", str(mapping))
return mapping
def __get_validated_args():
"""
Check that arguments make sense.
Checks should match the restrictions in the usage help messages.
:returns Namespace: Validated argparser argument list.
"""
parser = get_cli_arg_parser()
args = parser.parse_args()
try:
if args.haproxy_socket_timeout < 10:
raise ArgumentError(
"`--haproxy-socket-timeout` should be higher than 10."
)
except ArgumentError as exc:
parser.print_usage(sys.stderr)
logger.critical(
"Invalid command line argument or value: {}".format(exc)
)
exit(1)
return args
if __name__ == '__main__':
try:
......
......@@ -92,6 +92,7 @@ class Stapledaemon(object):
self.haproxy_socket_mapping = kwargs.pop(
'haproxy_socket_mapping', None
)
self.haproxy_socket_timeout = kwargs.pop('haproxy_socket_timeout')
self.file_extensions = kwargs.pop('file_extensions')
self.file_extensions = self.file_extensions.replace(" ", "").split(",")
self.renewal_threads = kwargs.pop('renewal_threads')
......@@ -170,6 +171,7 @@ class Stapledaemon(object):
name="proxy-adder",
thread_object=StapleAdder,
haproxy_socket_mapping=self.haproxy_socket_mapping,
haproxy_socket_timeout=self.haproxy_socket_timeout,
scheduler=self.scheduler
)
......
......@@ -60,3 +60,9 @@ class CertValidationError(Exception):
"""
pass
class ArgumentError(Exception):
"""
Raised when a command line argument has an invalid value.
"""
pass
......@@ -13,7 +13,7 @@ import stapled.core.exceptions
LOG = logging.getLogger(__name__)
SOCKET_BUFFER_SIZE = 1024
SOCKET_TIMEOUT = 86400
SOCKET_TIMEOUT = 300
class StapleAdder(threading.Thread):
......@@ -41,12 +41,6 @@ class StapleAdder(threading.Thread):
#: the base64 encoded OCSP staple
OCSP_ADD = 'set ssl ocsp-response {}'
#: Predefines commands to send to sockets just after opening them.
CONNECT_COMMANDS = [
"prompt",
"set timeout cli {}".format(SOCKET_TIMEOUT)
]
def __init__(self, *args, **kwargs):
"""
Initialise the thread and its parent :class:`threading.Thread`.
......@@ -65,11 +59,21 @@ class StapleAdder(threading.Thread):
self.haproxy_socket_mapping = kwargs.pop(
'haproxy_socket_mapping', None
)
self.haproxy_socket_timeout = kwargs.pop(
'haproxy_socket_timeout', None
)
assert self.scheduler is not None, \
"Please pass a scheduler to get and add proxy-add tasks."
assert self.haproxy_socket_mapping is not None, \
"The StapleAdder needs a haproxy_socket_mapping dict"
assert self.haproxy_socket_timeout is not None, \
"No timeout defined for haproxy socket connection."
# Predefines commands to send to sockets just after opening them.
self.connect_commands = [
"prompt",
"set timeout cli {}".format(self.haproxy_socket_timeout)
]
self.socks = {}
for paths in self.haproxy_socket_mapping.values():
......@@ -119,7 +123,7 @@ class StapleAdder(threading.Thread):
try:
sock.connect(path)
result = []
for command in self.CONNECT_COMMANDS:
for command in self.connect_commands:
result.extend(self._send(sock, command))
# Results (index 1) come per path (index 0), we need only results
result = [res[1] for res in result]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment