Commit 894b1163 authored by Chris Snijder's avatar Chris Snijder 🏅
Browse files

Merge branch '62-allow-running-one-off-for-debugging-purposes' into 'master'

Resolve "Allow running one-off for debugging purposes"

Closes #62

See merge request !44
parents ea51be2c f5996909
Pipeline #6268 passed with stages
in 7 minutes and 13 seconds
...@@ -86,9 +86,7 @@ test:stretch: ...@@ -86,9 +86,7 @@ test:stretch:
- openssl version - openssl version
- apt-get install -y -q ./dist/stapled_*all.deb - apt-get install -y -q ./dist/stapled_*all.deb
- /refresh_testdata.sh - /refresh_testdata.sh
- stapled -p /tmp/testdata/ --recursive --interactive --no-haproxy-sockets -vvvv & - stapled -p /tmp/testdata/ --recursive --interactive --no-haproxy-sockets -vvvv --one-off
- sleep 15
- ls /tmp/testdata/**/chain.pem.ocsp
dependencies: dependencies:
- build:package - build:package
...@@ -102,6 +100,4 @@ source:dev-setup: ...@@ -102,6 +100,4 @@ source:dev-setup:
- openssl version - openssl version
- pip3 install -e . - pip3 install -e .
- ./refresh_testdata.sh - ./refresh_testdata.sh
- stapled -p /tmp/testdata/ --recursive --interactive --no-haproxy-sockets -vvvv & - stapled -p /tmp/testdata/ --recursive --interactive --no-haproxy-sockets -vvvv --one-off
- sleep 15
- ls /tmp/testdata/**/chain.pem.ocsp
...@@ -92,11 +92,11 @@ haproxy-sockets=[/var/run/haproxy/admin.sock] ...@@ -92,11 +92,11 @@ haproxy-sockets=[/var/run/haproxy/admin.sock]
;; merged in the path to socket mapping. ;; merged in the path to socket mapping.
; haproxy-config=/etc/haproxy/haproxy.cfg ; haproxy-config=/etc/haproxy/haproxy.cfg
;; Set a keep alive time in seconds after wich the connection to the HAProxy ;; Set a keep alive time in seconds after which the connection to the HAProxy
;; sockets is terminated. The minimum allowed value is 10 seconds, because ;; sockets is terminated. The minimum allowed value is 1 second, because
;; stapled will take at least a bit of time to communicate with HAProxy, and ;; stapled will take at least a bit of time to communicate with HAProxy, and
;; either process could be "busy". ;; either process could be "busy".
; haproxy-socket-keepalive=3600 ; haproxy-socket-keepalive=10
;; Don't output anything to stdout, can be used together with `logdir` ;; Don't output anything to stdout, can be used together with `logdir`
......
# -*- coding: utf-8 -*-
""" """
Initialise the stapled module. Initialise the stapled module.
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*-
""" """
Parse command line arguments and starts the OCSP Staple daemon. Parse command line arguments and starts the OCSP Staple daemon.
...@@ -39,12 +38,12 @@ import daemon ...@@ -39,12 +38,12 @@ import daemon
import stapled import stapled
import stapled.core.daemon import stapled.core.daemon
import stapled.core.excepthandler import stapled.core.excepthandler
from stapled.core.excepthandler import handle_file_error
from stapled.core.exceptions import ArgumentError from stapled.core.exceptions import ArgumentError
from stapled.util.haproxy import parse_haproxy_config from stapled.util.haproxy import parse_haproxy_config
from stapled.colourlog import ColourFormatter from stapled.colourlog import ColourFormatter
from stapled.version import __version__, __app_name__ from stapled.version import __version__, __app_name__
from stapled.util.functions import unique from stapled.util.functions import unique
from stapled.util.exitcode import ExitCodeTracker
#: :attr:`logging.format` format string for log files and syslog #: :attr:`logging.format` format string for log files and syslog
LOGFORMAT = ( LOGFORMAT = (
...@@ -228,13 +227,13 @@ def get_cli_arg_parser(): ...@@ -228,13 +227,13 @@ def get_cli_arg_parser():
parser.add( parser.add(
'--haproxy-socket-keepalive', '--haproxy-socket-keepalive',
type=int, type=int,
default=3600, default=10,
metavar="KEEP-ALIVE <seconds, minimum: 10>", metavar="KEEP-ALIVE <seconds, minimum: 10>",
help=( help=(
"HAProxy sockets are kept open for performance reasons, you can " "HAProxy sockets are kept open for performance reasons, you can "
"set the amount of seconds sockets should remain open " "set the amount of seconds sockets should remain open "
"(default=3600). Note that a short amount of time is required to " "(default=10). Note that a short amount of time is required to "
"to pass messages to HAProxy, so 10 seconds if the minimum " "to pass messages to HAProxy, so 1 second if the minimum "
"accepted value." "accepted value."
) )
) )
...@@ -316,6 +315,16 @@ def get_cli_arg_parser(): ...@@ -316,6 +315,16 @@ def get_cli_arg_parser():
"DEPRECATED, please see ``--cert-paths``." "DEPRECATED, please see ``--cert-paths``."
) )
) )
parser.add(
'--one-off',
action='store_true',
default=False,
help=(
"Index cert_paths and fetch staples only once and then exit. "
"This overrides the --refresh-interval argument. The --daemon and "
"--no-daemon arguments are also ignored."
)
)
return parser return parser
...@@ -329,7 +338,7 @@ def init(): ...@@ -329,7 +338,7 @@ def init():
""" """
args = __get_validated_args() args = __get_validated_args()
log_file_handles = __init_logging(args) log_file_handles, exit_code_tracker = __init_logging(args)
# Get a mapping of configured sockets and certificate directories from: # Get a mapping of configured sockets and certificate directories from:
# haproxy config, stapled config and command line arguments # haproxy config, stapled config and command line arguments
...@@ -350,10 +359,12 @@ def init(): ...@@ -350,10 +359,12 @@ def init():
file_extensions=args.file_extensions, file_extensions=args.file_extensions,
renewal_threads=args.renewal_threads, renewal_threads=args.renewal_threads,
refresh_interval=args.refresh_interval, refresh_interval=args.refresh_interval,
one_off=args.one_off,
minimum_validity=args.minimum_validity, minimum_validity=args.minimum_validity,
recursive=args.recursive, recursive=args.recursive,
no_recycle=args.no_recycle, no_recycle=args.no_recycle,
ignore=args.ignore ignore=args.ignore,
exit_code_tracker=exit_code_tracker
) )
if stapled.LOCAL_LIB_MODE: if stapled.LOCAL_LIB_MODE:
...@@ -449,12 +460,19 @@ def __init_logging(args): ...@@ -449,12 +460,19 @@ def __init_logging(args):
logging.Formatter(LOGFORMAT, TIMESTAMP_FORMAT) logging.Formatter(LOGFORMAT, TIMESTAMP_FORMAT)
) )
logger.addHandler(syslog_handler) logger.addHandler(syslog_handler)
return log_file_handles if args.one_off:
# Keep track of errors so we can return a greater than 0 exit code when
# errors occurred.
exit_code_tracker = ExitCodeTracker(logging.WARN)
logger.addHandler(exit_code_tracker)
else:
exit_code_tracker = None
return log_file_handles, exit_code_tracker
def __get_haproxy_socket_mapping(args): def __get_haproxy_socket_mapping(args):
""" """
Get mapping of configured sockets and certificate directories. Get a mapping of configured sockets and certificate directories.
From: haproxy config, stapled config and command line arguments. From: haproxy config, stapled config and command line arguments.
...@@ -473,8 +491,8 @@ def __get_haproxy_socket_mapping(args): ...@@ -473,8 +491,8 @@ def __get_haproxy_socket_mapping(args):
conf_cert_paths, conf_haproxy_sockets = parse_haproxy_config( conf_cert_paths, conf_haproxy_sockets = parse_haproxy_config(
args.haproxy_config args.haproxy_config
) )
except (OSError, IOError) as exc: except (OSError) as exc:
logger.critical(handle_file_error(exc)) logger.critical(exc)
exit(1) exit(1)
# Combine the socket and certificate paths of the arguments and config # Combine the socket and certificate paths of the arguments and config
...@@ -495,7 +513,7 @@ def __get_haproxy_socket_mapping(args): ...@@ -495,7 +513,7 @@ def __get_haproxy_socket_mapping(args):
def __get_validated_args(): def __get_validated_args():
""" """
Parse and validate CLI arguments and configuration. Check that arguments make sense.
Checks should match the restrictions in the usage help messages. Checks should match the restrictions in the usage help messages.
...@@ -504,17 +522,62 @@ def __get_validated_args(): ...@@ -504,17 +522,62 @@ def __get_validated_args():
parser = get_cli_arg_parser() parser = get_cli_arg_parser()
args = parser.parse_args() args = parser.parse_args()
try: try:
if args.haproxy_socket_keepalive < 10: if args.haproxy_socket_keepalive < 1:
raise ArgumentError( raise ArgumentError(
"`--haproxy-socket-keepalive` should be higher than 10." "`--haproxy-socket-keepalive` should be 1 or higher."
) )
except ArgumentError as exc: except ArgumentError as exc:
parser.print_usage(sys.stderr) parser.print_usage(sys.stderr)
logger.critical("Invalid command line argument or value: %s", exc) logger.critical("Invalid command line argument or value: %s", exc)
exit(1) exit(1)
# Run in one-off mode, run once then exit.
if args.one_off:
args.refresh_interval = None
args.daemon = False
return args return args
def __get_haproxy_socket_mapping(args):
"""
Get mapping of configured sockets and certificate directories.
From: haproxy config, stapled config and command line arguments.
:param Namespace args: Argparser argument list.
:return dict Of cert-paths and sockets for inform of changes.
"""
# Parse the cert_paths argument
arg_cert_paths = __get_arg_cert_paths(args)
# Parse haproxy_sockets argument.
arg_haproxy_sockets = __get_arg_haproxy_sockets(args)
# Make a mapping from certificate paths to sockets in a dict.
mapping = dict(zip(arg_cert_paths, arg_haproxy_sockets))
# Parse HAProxy config files.
try:
conf_cert_paths, conf_haproxy_sockets = parse_haproxy_config(
args.haproxy_config
)
except (OSError) as exc:
logger.critical(exc)
exit(1)
# Combine the socket and certificate paths of the arguments and config
# files in the sockets dictionary.
for i, paths in enumerate(conf_cert_paths):
for path in paths:
if path in mapping:
mapping[path] = unique(
mapping[path] + conf_haproxy_sockets[i],
preserve_order=False
)
else:
mapping[path] = conf_haproxy_sockets[i]
logger.debug("Paths to socket mapping: %s", str(mapping))
return mapping
if __name__ == '__main__': if __name__ == '__main__':
try: try:
init() init()
......
# -*- coding: utf-8 -*-
""" """
ANSI colourise the logging stream (works on LINUX/UNIX based systems). ANSI colourise the logging stream (works on LINUX/UNIX based systems).
......
# -*- coding: utf-8 -*-
""" """
Test the ColourFormatter class when run directly. Test the ColourFormatter class when run directly.
""" """
......
# -*- coding: utf-8 -*-
""" """
This module locates certificate files in the supplied paths and parses Locate certificate files in the supplied paths and parse them.
them. It then keeps track of the following:
It also keeps track of the following:
- If cert is found for the first time (thus also when the daemon is started), - If cert is found for the first time (thus also when the daemon is started),
the cert is added to the :attr:`stapled.core.certfinder.CertFinder.scheduler` the cert is added to the :attr:`stapled.core.certfinder.CertFinder.scheduler`
...@@ -28,7 +28,6 @@ import logging ...@@ -28,7 +28,6 @@ import logging
import fnmatch import fnmatch
import os import os
import errno import errno
import stapled
from stapled.core.excepthandler import stapled_except_handle from stapled.core.excepthandler import stapled_except_handle
from stapled.core.taskcontext import StapleTaskContext from stapled.core.taskcontext import StapleTaskContext
from stapled.core.certmodel import CertModel from stapled.core.certmodel import CertModel
...@@ -39,7 +38,8 @@ LOG = logging.getLogger(__name__) ...@@ -39,7 +38,8 @@ LOG = logging.getLogger(__name__)
class CertFinderThread(threading.Thread): class CertFinderThread(threading.Thread):
""" """
This searches paths for certificate files. A thread that searches paths for certificate files.
When found, models are created for the certificate files, which are wrapped When found, models are created for the certificate files, which are wrapped
in a :class:`stapled.core.taskcontext.StapleTaskContext` which are then in a :class:`stapled.core.taskcontext.StapleTaskContext` which are then
scheduled to be processed by the scheduled to be processed by the
...@@ -48,19 +48,18 @@ class CertFinderThread(threading.Thread): ...@@ -48,19 +48,18 @@ class CertFinderThread(threading.Thread):
Pass ``refresh_interval=None`` if you want to run it only once (e.g. for Pass ``refresh_interval=None`` if you want to run it only once (e.g. for
testing) testing)
""" """
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
Initialise the thread with its parent :class:`threading.Thread` and its Initialise with parent :class:`threading.Thread` and its arguments.
arguments.
:kwarg dict models: A dict to maintain a model cache **(required)**. :kwarg dict models: A dict to maintain a model cache **(required)**.
:kwarg iter cert_paths: The paths to index **(required)**. :kwarg iter cert_paths: The paths to index **(required)**.
:kwarg stapled.scheduling.SchedulerThread scheduler: The scheduler :kwarg stapled.scheduling.SchedulerThread scheduler: The scheduler
object where we add new parse tasks to. **(required)**. object where we add new parse tasks to. **(required)**.
:kwarg int refresh_interval: The minimum amount of time (s) :kwarg int refresh_interval: The minimum amount of time (s) between
between search runs, defaults to 10 seconds. Set to None to run search runs. Set to None (default) to run once **(optional)**.
only once **(optional)**.
:kwarg array file_extensions: An array containing the file extensions :kwarg array file_extensions: An array containing the file extensions
of file types to check for certificate content **(optional)**. of file types to check for certificate content **(optional)**.
""" """
...@@ -68,12 +67,8 @@ class CertFinderThread(threading.Thread): ...@@ -68,12 +67,8 @@ class CertFinderThread(threading.Thread):
self.models = kwargs.pop('models', None) self.models = kwargs.pop('models', None)
self.cert_paths = kwargs.pop('cert_paths', None) self.cert_paths = kwargs.pop('cert_paths', None)
self.scheduler = kwargs.pop('scheduler', None) self.scheduler = kwargs.pop('scheduler', None)
self.refresh_interval = kwargs.pop( self.refresh_interval = kwargs.pop('refresh_interval', None)
'refresh_interval', stapled.DEFAULT_REFRESH_INTERVAL self.file_extensions = kwargs.pop('file_extensions', None)
)
self.file_extensions = kwargs.pop(
'file_extensions', stapled.FILE_EXTENSIONS_DEFAULT
)
self.last_refresh = None self.last_refresh = None
self.ignore = kwargs.pop('ignore', []) or [] self.ignore = kwargs.pop('ignore', []) or []
self.recursive = kwargs.pop('recursive', False) self.recursive = kwargs.pop('recursive', False)
...@@ -84,13 +79,21 @@ class CertFinderThread(threading.Thread): ...@@ -84,13 +79,21 @@ class CertFinderThread(threading.Thread):
assert self.cert_paths is not None, \ assert self.cert_paths is not None, \
"At least one path should be passed for indexing." "At least one path should be passed for indexing."
assert self.file_extensions is not None, \
"Please specify file extensions to search for certificates."
assert self.scheduler is not None, \ assert self.scheduler is not None, \
"Please pass a scheduler to get tasks from and add tasks to." "Please pass a scheduler to get tasks from and add tasks to."
super(CertFinderThread, self).__init__(*args, **kwargs) super(CertFinderThread, self).__init__(*args, **kwargs)
def run(self): def run(self):
"""Start the certificate finder thread.""" """
Start the certificate finder thread.
The "scheduling" mentioned in this method does not use the scheduler.
It will sleep instead, only because it is simpler.
"""
LOG.info("Scanning paths: '%s'", "', '".join(self.cert_paths)) LOG.info("Scanning paths: '%s'", "', '".join(self.cert_paths))
while not self.stop: while not self.stop:
# Catch any exceptions within this context to protect the thread. # Catch any exceptions within this context to protect the thread.
...@@ -103,7 +106,7 @@ class CertFinderThread(threading.Thread): ...@@ -103,7 +106,7 @@ class CertFinderThread(threading.Thread):
since_last = time.time() - self.last_refresh since_last = time.time() - self.last_refresh
# Check if the last refresh took longer than the interval.. # Check if the last refresh took longer than the interval..
if since_last > self.refresh_interval: if since_last > self.refresh_interval:
# It did so start right now.. # It did take longer than the interval so, start right now
LOG.info( LOG.info(
"Starting a new refresh immediately because the last " "Starting a new refresh immediately because the last "
"refresh took %0.3f seconds while the minimum " "refresh took %0.3f seconds while the minimum "
...@@ -131,6 +134,8 @@ class CertFinderThread(threading.Thread): ...@@ -131,6 +134,8 @@ class CertFinderThread(threading.Thread):
def refresh(self): def refresh(self):
""" """
Refresh the index.
Wrap up the internal :meth:`CertFinder._update_cached_certs()` and Wrap up the internal :meth:`CertFinder._update_cached_certs()` and
:meth:`CertFinder._find_new_certs()` functions. :meth:`CertFinder._find_new_certs()` functions.
...@@ -167,7 +172,7 @@ class CertFinderThread(threading.Thread): ...@@ -167,7 +172,7 @@ class CertFinderThread(threading.Thread):
dirs = [] dirs = []
try: try:
dirs = os.listdir(path) dirs = os.listdir(path)
except (OSError, IOError) as exc: except (OSError) as exc:
# If a path is actually a file we can still use it.. # If a path is actually a file we can still use it..
if exc.errno == errno.ENOTDIR and os.path.isfile(path): if exc.errno == errno.ENOTDIR and os.path.isfile(path):
LOG.debug("%s may be a single file", path) LOG.debug("%s may be a single file", path)
...@@ -206,7 +211,7 @@ class CertFinderThread(threading.Thread): ...@@ -206,7 +211,7 @@ class CertFinderThread(threading.Thread):
sched_time=None sched_time=None
) )
self.scheduler.add_task(context) self.scheduler.add_task(context)
except (IOError, OSError) as exc: except (OSError) as exc:
# If the directory is unreadable this gets printed at every # If the directory is unreadable this gets printed at every
# refresh until the directory is readable. We catch this here # refresh until the directory is readable. We catch this here
# so any readable directory can still be scanned. # so any readable directory can still be scanned.
......
# -*- coding: utf-8 -*-
""" """
This module defines the :class:`stapled.core.certmodel.CertModel` class which is This module defines the :class:`stapled.core.certmodel.CertModel` class which is
used to keep track of certificates that are found by the used to keep track of certificates that are found by the
......
# -*- coding: utf-8 -*-
""" """
This module parses certificate in a queue so the data contained in the This module parses certificate in a queue so the data contained in the
certificate can be used to request OCSP responses. After parsing a new certificate can be used to request OCSP responses. After parsing a new
......
# -*- coding: utf-8 -*-
""" """
This module bootstraps the stapled process by starting threads for: This module bootstraps the stapled process by starting threads for:
...@@ -57,7 +56,7 @@ from stapled.core.certfinder import CertFinderThread ...@@ -57,7 +56,7 @@ from stapled.core.certfinder import CertFinderThread
from stapled.core.certparser import CertParserThread from stapled.core.certparser import CertParserThread
from stapled.core.staplerenewer import StapleRenewerThread from stapled.core.staplerenewer import StapleRenewerThread
from stapled.core.stapleadder import StapleAdder from stapled.core.stapleadder import StapleAdder
from stapled.scheduling import SchedulerThread from stapled.scheduling import SchedulerThread, QueueError
from stapled import MAX_RESTART_THREADS from stapled import MAX_RESTART_THREADS
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
...@@ -79,7 +78,7 @@ class Stapledaemon(object): ...@@ -79,7 +78,7 @@ class Stapledaemon(object):
:kwarg list file_extensions: List of file extensions to search for :kwarg list file_extensions: List of file extensions to search for
certificates. certificates.
:kwarg int renewal_threads: Amount of staple renewal threads. :kwarg int renewal_threads: Amount of staple renewal threads.
:kwarg int refresh_interval: Interval between re-indexing of :kwarg NoneType|int refresh_interval: Interval between re-indexing of
certificate paths. certificate paths.
:kwarg int minimum_validity: Minimum validity of stapled before :kwarg int minimum_validity: Minimum validity of stapled before
renewing. renewing.
...@@ -97,9 +96,11 @@ class Stapledaemon(object): ...@@ -97,9 +96,11 @@ class Stapledaemon(object):
self.file_extensions = self.file_extensions.replace(" ", "").split(",") self.file_extensions = self.file_extensions.replace(" ", "").split(",")
self.renewal_threads = kwargs.pop('renewal_threads') self.renewal_threads = kwargs.pop('renewal_threads')
self.refresh_interval = kwargs.pop('refresh_interval') self.refresh_interval = kwargs.pop('refresh_interval')
self.one_off = kwargs.pop('one_off')
self.minimum_validity = kwargs.pop('minimum_validity') self.minimum_validity = kwargs.pop('minimum_validity')
self.recursive = kwargs.pop('recursive') self.recursive = kwargs.pop('recursive')
self.no_recycle = kwargs.pop('no_recycle') self.no_recycle = kwargs.pop('no_recycle')
self.exit_code_tracker = kwargs.pop('exit_code_tracker')
self.ignore = [] self.ignore = []
rel_path_re = re.compile(r'^\.+\/') rel_path_re = re.compile(r'^\.+\/')
...@@ -136,21 +137,24 @@ class Stapledaemon(object): ...@@ -136,21 +137,24 @@ class Stapledaemon(object):
# Scheduler thread # Scheduler thread
self.scheduler = self.start_scheduler_thread() self.scheduler = self.start_scheduler_thread()
self.staple_adder = None
# Start proxy adder thread if sockets were supplied # Start proxy adder thread if sockets were supplied
if self.haproxy_socket_mapping: if self.haproxy_socket_mapping:
self.start_staple_adder_thread() self.staple_adder = self.start_staple_adder_thread()
# Start ocsp response gathering threads # Start ocsp response gathering threads
threads_list = [] self.renewers = []
for tid in range(0, self.renewal_threads): for tid in range(0, self.renewal_threads):
threads_list.append(self.start_renewer_thread(tid)) self.renewers.append(self.start_renewer_thread(tid))
# Start certificate parser thread # Start certificate parser thread
self.parser = self.start_parser_thread() self.parser = self.start_parser_thread()
# Start certificate finding thread # Start certificate finding thread
self.finder = self.start_finder_thread() self.finder = self.start_finder_thread()
if self.one_off:
self.monitor_threads() self.handle_one_off()
else:
self.monitor_threads()
def exit_gracefully(self, signum, _frame): def exit_gracefully(self, signum, _frame):
"""Set self.stop so the main thread stops.""" """Set self.stop so the main thread stops."""
...@@ -261,6 +265,65 @@ class Stapledaemon(object): ...@@ -261,6 +265,65 @@ class Stapledaemon(object):
pass # cannot join current thread pass # cannot join current thread
LOG.info("Stopping daemon thread") LOG.info("Stopping daemon thread")
def handle_one_off(self):
"""
Stop threads that are done so we can do a one-off run.
- When the certfinder is done and the parsing queue is empty, we can
end the certparser thread.
- When the certparser is done and and the renewal queue is empty, we
can end the staplerenewers.