diff --git a/bin/clear_rabbit_queues b/bin/clear_rabbit_queues index 503fa07d8156..29298f18f3b5 100755 --- a/bin/clear_rabbit_queues +++ b/bin/clear_rabbit_queues @@ -74,6 +74,7 @@ if __name__ == '__main__': utils.default_flagfile() args = flags.FLAGS(sys.argv) logging.setup() + rpc.register_opts(flags.FLAGS) delete_queues(args[1:]) if FLAGS.delete_exchange: delete_exchange(FLAGS.control_exchange) diff --git a/bin/nova-dhcpbridge b/bin/nova-dhcpbridge index 162fdd5a6e1f..a3cbe168c803 100755 --- a/bin/nova-dhcpbridge +++ b/bin/nova-dhcpbridge @@ -99,6 +99,8 @@ def main(): argv = FLAGS(sys.argv) logging.setup() + rpc.register_opts(FLAGS) + if int(os.environ.get('TESTING', '0')): from nova.tests import fake_flags diff --git a/bin/nova-manage b/bin/nova-manage index bd68578188f7..dbb22d7945b4 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -1675,6 +1675,8 @@ def main(): except Exception: print 'sudo failed, continuing as if nothing happened' + rpc.register_opts(FLAGS) + try: argv = FLAGS(sys.argv) logging.setup() @@ -1683,7 +1685,6 @@ def main(): print _('Please re-run nova-manage as root.') sys.exit(2) raise - script_name = argv.pop(0) if len(argv) < 1: print _("\nOpenStack Nova version: %(version)s (%(vcs)s)\n") % \ diff --git a/nova/rpc/__init__.py b/nova/rpc/__init__.py index 4acc5634cb1c..45d8c00b2d2d 100644 --- a/nova/rpc/__init__.py +++ b/nova/rpc/__init__.py @@ -17,17 +17,37 @@ # License for the specific language governing permissions and limitations # under the License. -from nova import flags from nova.openstack.common import cfg from nova import utils -rpc_backend_opt = cfg.StrOpt('rpc_backend', - default='nova.rpc.impl_kombu', - help="The messaging module to use, defaults to kombu.") +rpc_opts = [ + cfg.StrOpt('rpc_backend', + default='nova.rpc.impl_kombu', + help="The messaging module to use, defaults to kombu."), + cfg.IntOpt('rpc_thread_pool_size', + default=64, + help='Size of RPC thread pool'), + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.IntOpt('rpc_response_timeout', + default=60, + help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('allowed_rpc_exception_modules', + default=['nova.exception'], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), + ] -FLAGS = flags.FLAGS -FLAGS.register_opt(rpc_backend_opt) +_CONF = None + + +def register_opts(conf): + global _CONF + _CONF = conf + _CONF.register_opts(rpc_opts) + _get_impl().register_opts(_CONF) def create_connection(new=True): @@ -43,7 +63,7 @@ def create_connection(new=True): :returns: An instance of nova.rpc.common.Connection """ - return _get_impl().create_connection(new=new) + return _get_impl().create_connection(_CONF, new=new) def call(context, topic, msg, timeout=None): @@ -65,7 +85,7 @@ def call(context, topic, msg, timeout=None): :raises: nova.rpc.common.Timeout if a complete response is not received before the timeout is reached. """ - return _get_impl().call(context, topic, msg, timeout) + return _get_impl().call(_CONF, context, topic, msg, timeout) def cast(context, topic, msg): @@ -82,7 +102,7 @@ def cast(context, topic, msg): :returns: None """ - return _get_impl().cast(context, topic, msg) + return _get_impl().cast(_CONF, context, topic, msg) def fanout_cast(context, topic, msg): @@ -102,7 +122,7 @@ def fanout_cast(context, topic, msg): :returns: None """ - return _get_impl().fanout_cast(context, topic, msg) + return _get_impl().fanout_cast(_CONF, context, topic, msg) def multicall(context, topic, msg, timeout=None): @@ -131,7 +151,7 @@ def multicall(context, topic, msg, timeout=None): :raises: nova.rpc.common.Timeout if a complete response is not received before the timeout is reached. """ - return _get_impl().multicall(context, topic, msg, timeout) + return _get_impl().multicall(_CONF, context, topic, msg, timeout) def notify(context, topic, msg): @@ -144,7 +164,7 @@ def notify(context, topic, msg): :returns: None """ - return _get_impl().notify(context, topic, msg) + return _get_impl().notify(_CONF, context, topic, msg) def cleanup(): @@ -172,7 +192,8 @@ def cast_to_server(context, server_params, topic, msg): :returns: None """ - return _get_impl().cast_to_server(context, server_params, topic, msg) + return _get_impl().cast_to_server(_CONF, context, server_params, topic, + msg) def fanout_cast_to_server(context, server_params, topic, msg): @@ -187,16 +208,16 @@ def fanout_cast_to_server(context, server_params, topic, msg): :returns: None """ - return _get_impl().fanout_cast_to_server(context, server_params, topic, - msg) + return _get_impl().fanout_cast_to_server(_CONF, context, server_params, + topic, msg) _RPCIMPL = None def _get_impl(): - """Delay import of rpc_backend until FLAGS are loaded.""" + """Delay import of rpc_backend until configuration is loaded.""" global _RPCIMPL if _RPCIMPL is None: - _RPCIMPL = utils.import_object(FLAGS.rpc_backend) + _RPCIMPL = utils.import_object(_CONF.rpc_backend) return _RPCIMPL diff --git a/nova/rpc/amqp.py b/nova/rpc/amqp.py index ac29a625d948..d58806c9e644 100644 --- a/nova/rpc/amqp.py +++ b/nova/rpc/amqp.py @@ -31,10 +31,10 @@ import uuid from eventlet import greenpool from eventlet import pools +from eventlet import semaphore from nova import context from nova import exception -from nova import flags from nova import log as logging from nova.openstack.common import local import nova.rpc.common as rpc_common @@ -43,27 +43,36 @@ from nova import utils LOG = logging.getLogger(__name__) -FLAGS = flags.FLAGS - - class Pool(pools.Pool): """Class that implements a Pool of Connections.""" - def __init__(self, *args, **kwargs): - self.connection_cls = kwargs.pop("connection_cls", None) - kwargs.setdefault("max_size", FLAGS.rpc_conn_pool_size) + def __init__(self, conf, connection_cls, *args, **kwargs): + self.connection_cls = connection_cls + self.conf = conf + kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size) kwargs.setdefault("order_as_stack", True) super(Pool, self).__init__(*args, **kwargs) # TODO(comstud): Timeout connections not used in a while def create(self): LOG.debug('Pool creating new connection') - return self.connection_cls() + return self.connection_cls(self.conf) def empty(self): while self.free_items: self.get().close() +_pool_create_sem = semaphore.Semaphore() + + +def get_connection_pool(conf, connection_cls): + with _pool_create_sem: + # Make sure only one thread tries to create the connection pool. + if not connection_cls.pool: + connection_cls.pool = Pool(conf, connection_cls) + return connection_cls.pool + + class ConnectionContext(rpc_common.Connection): """The class that is actually returned to the caller of create_connection(). This is a essentially a wrapper around @@ -75,14 +84,15 @@ class ConnectionContext(rpc_common.Connection): the pool. """ - def __init__(self, connection_pool, pooled=True, server_params=None): + def __init__(self, conf, connection_pool, pooled=True, server_params=None): """Create a new connection, or get one from the pool""" self.connection = None + self.conf = conf self.connection_pool = connection_pool if pooled: self.connection = connection_pool.get() else: - self.connection = connection_pool.connection_cls( + self.connection = connection_pool.connection_cls(conf, server_params=server_params) self.pooled = pooled @@ -133,13 +143,14 @@ class ConnectionContext(rpc_common.Connection): raise exception.InvalidRPCConnectionReuse() -def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False): +def msg_reply(conf, msg_id, connection_pool, reply=None, failure=None, + ending=False): """Sends a reply or an error on the channel signified by msg_id. Failure should be a sys.exc_info() tuple. """ - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: if failure: failure = rpc_common.serialize_remote_exception(failure) @@ -158,18 +169,19 @@ class RpcContext(context.RequestContext): """Context that supports replying to a rpc.call""" def __init__(self, *args, **kwargs): self.msg_id = kwargs.pop('msg_id', None) + self.conf = kwargs.pop('conf') super(RpcContext, self).__init__(*args, **kwargs) def reply(self, reply=None, failure=None, ending=False, connection_pool=None): if self.msg_id: - msg_reply(self.msg_id, connection_pool, reply, failure, + msg_reply(self.conf, self.msg_id, connection_pool, reply, failure, ending) if ending: self.msg_id = None -def unpack_context(msg): +def unpack_context(conf, msg): """Unpack context from msg.""" context_dict = {} for key in list(msg.keys()): @@ -180,6 +192,7 @@ def unpack_context(msg): value = msg.pop(key) context_dict[key[9:]] = value context_dict['msg_id'] = msg.pop('_msg_id', None) + context_dict['conf'] = conf ctx = RpcContext.from_dict(context_dict) rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) return ctx @@ -202,10 +215,11 @@ def pack_context(msg, context): class ProxyCallback(object): """Calls methods on a proxy object based on method and args.""" - def __init__(self, proxy, connection_pool): + def __init__(self, conf, proxy, connection_pool): self.proxy = proxy - self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) + self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size) self.connection_pool = connection_pool + self.conf = conf def __call__(self, message_data): """Consumer callback to call a method on a proxy object. @@ -225,7 +239,7 @@ class ProxyCallback(object): if hasattr(local.store, 'context'): del local.store.context rpc_common._safe_log(LOG.debug, _('received %s'), message_data) - ctxt = unpack_context(message_data) + ctxt = unpack_context(self.conf, message_data) method = message_data.get('method') args = message_data.get('args', {}) if not method: @@ -262,13 +276,14 @@ class ProxyCallback(object): class MulticallWaiter(object): - def __init__(self, connection, timeout): + def __init__(self, conf, connection, timeout): self._connection = connection self._iterator = connection.iterconsume( - timeout=timeout or FLAGS.rpc_response_timeout) + timeout=timeout or conf.rpc_response_timeout) self._result = None self._done = False self._got_ending = False + self._conf = conf def done(self): if self._done: @@ -282,7 +297,8 @@ class MulticallWaiter(object): """The consume() callback will call this. Store the result.""" if data['failure']: failure = data['failure'] - self._result = rpc_common.deserialize_remote_exception(failure) + self._result = rpc_common.deserialize_remote_exception(self._conf, + failure) elif data.get('ending', False): self._got_ending = True @@ -309,12 +325,12 @@ class MulticallWaiter(object): yield result -def create_connection(new, connection_pool): +def create_connection(conf, new, connection_pool): """Create a connection""" - return ConnectionContext(connection_pool, pooled=not new) + return ConnectionContext(conf, connection_pool, pooled=not new) -def multicall(context, topic, msg, timeout, connection_pool): +def multicall(conf, context, topic, msg, timeout, connection_pool): """Make a call that returns multiple times.""" # Can't use 'with' for multicall, as it returns an iterator # that will continue to use the connection. When it's done, @@ -326,16 +342,16 @@ def multicall(context, topic, msg, timeout, connection_pool): LOG.debug(_('MSG_ID is %s') % (msg_id)) pack_context(msg, context) - conn = ConnectionContext(connection_pool) - wait_msg = MulticallWaiter(conn, timeout) + conn = ConnectionContext(conf, connection_pool) + wait_msg = MulticallWaiter(conf, conn, timeout) conn.declare_direct_consumer(msg_id, wait_msg) conn.topic_send(topic, msg) return wait_msg -def call(context, topic, msg, timeout, connection_pool): +def call(conf, context, topic, msg, timeout, connection_pool): """Sends a message on a topic and wait for a response.""" - rv = multicall(context, topic, msg, timeout, connection_pool) + rv = multicall(conf, context, topic, msg, timeout, connection_pool) # NOTE(vish): return the last result from the multicall rv = list(rv) if not rv: @@ -343,47 +359,48 @@ def call(context, topic, msg, timeout, connection_pool): return rv[-1] -def cast(context, topic, msg, connection_pool): +def cast(conf, context, topic, msg, connection_pool): """Sends a message on a topic without waiting for a response.""" LOG.debug(_('Making asynchronous cast on %s...'), topic) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.topic_send(topic, msg) -def fanout_cast(context, topic, msg, connection_pool): +def fanout_cast(conf, context, topic, msg, connection_pool): """Sends a message on a fanout exchange without waiting for a response.""" LOG.debug(_('Making asynchronous fanout cast...')) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.fanout_send(topic, msg) -def cast_to_server(context, server_params, topic, msg, connection_pool): +def cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a topic to a specific server.""" pack_context(msg, context) - with ConnectionContext(connection_pool, pooled=False, + with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: conn.topic_send(topic, msg) -def fanout_cast_to_server(context, server_params, topic, msg, +def fanout_cast_to_server(conf, context, server_params, topic, msg, connection_pool): """Sends a message on a fanout exchange to a specific server.""" pack_context(msg, context) - with ConnectionContext(connection_pool, pooled=False, + with ConnectionContext(conf, connection_pool, pooled=False, server_params=server_params) as conn: conn.fanout_send(topic, msg) -def notify(context, topic, msg, connection_pool): +def notify(conf, context, topic, msg, connection_pool): """Sends a notification event on a topic.""" event_type = msg.get('event_type') LOG.debug(_('Sending %(event_type)s on %(topic)s'), locals()) pack_context(msg, context) - with ConnectionContext(connection_pool) as conn: + with ConnectionContext(conf, connection_pool) as conn: conn.notify_send(topic, msg) def cleanup(connection_pool): - connection_pool.empty() + if connection_pool: + connection_pool.empty() diff --git a/nova/rpc/common.py b/nova/rpc/common.py index a2975e9a57f5..0b9eebf0fa54 100644 --- a/nova/rpc/common.py +++ b/nova/rpc/common.py @@ -22,7 +22,6 @@ import sys import traceback from nova import exception -from nova import flags from nova import log as logging from nova.openstack.common import cfg from nova import utils @@ -30,25 +29,6 @@ from nova import utils LOG = logging.getLogger(__name__) -rpc_opts = [ - cfg.IntOpt('rpc_thread_pool_size', - default=64, - help='Size of RPC thread pool'), - cfg.IntOpt('rpc_conn_pool_size', - default=30, - help='Size of RPC connection pool'), - cfg.IntOpt('rpc_response_timeout', - default=60, - help='Seconds to wait for a response from call or multicall'), - cfg.IntOpt('allowed_rpc_exception_modules', - default=['nova.exception'], - help='Modules of exceptions that are permitted to be recreated' - 'upon receiving exception data from an rpc call.'), - ] - -flags.FLAGS.register_opts(rpc_opts) -FLAGS = flags.FLAGS - class RemoteError(exception.NovaException): """Signifies that a remote class has raised an exception. @@ -95,7 +75,7 @@ class Connection(object): """ raise NotImplementedError() - def create_consumer(self, topic, proxy, fanout=False): + def create_consumer(self, conf, topic, proxy, fanout=False): """Create a consumer on this connection. A consumer is associated with a message queue on the backend message @@ -104,6 +84,7 @@ class Connection(object): off of the queue will determine which method gets called on the proxy object. + :param conf: An openstack.common.cfg configuration object. :param topic: This is a name associated with what to consume from. Multiple instances of a service may consume from the same topic. For example, all instances of nova-compute consume @@ -197,7 +178,7 @@ def serialize_remote_exception(failure_info): return json_data -def deserialize_remote_exception(data): +def deserialize_remote_exception(conf, data): failure = utils.loads(str(data)) trace = failure.get('tb', []) @@ -207,7 +188,7 @@ def deserialize_remote_exception(data): # NOTE(ameade): We DO NOT want to allow just any module to be imported, in # order to prevent arbitrary code execution. - if not module in FLAGS.allowed_rpc_exception_modules: + if not module in conf.allowed_rpc_exception_modules: return RemoteError(name, failure.get('message'), trace) try: diff --git a/nova/rpc/impl_fake.py b/nova/rpc/impl_fake.py index 43aed15c2643..065cca699e7e 100644 --- a/nova/rpc/impl_fake.py +++ b/nova/rpc/impl_fake.py @@ -27,13 +27,10 @@ import traceback import eventlet from nova import context -from nova import flags from nova.rpc import common as rpc_common CONSUMERS = {} -FLAGS = flags.FLAGS - class RpcContext(context.RequestContext): def __init__(self, *args, **kwargs): @@ -116,7 +113,7 @@ class Connection(object): pass -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" return Connection() @@ -126,7 +123,7 @@ def check_serialize(msg): json.dumps(msg) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" check_serialize(msg) @@ -144,9 +141,9 @@ def multicall(context, topic, msg, timeout=None): return consumer.call(context, method, args, timeout) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - rv = multicall(context, topic, msg, timeout) + rv = multicall(conf, context, topic, msg, timeout) # NOTE(vish): return the last result from the multicall rv = list(rv) if not rv: @@ -154,14 +151,14 @@ def call(context, topic, msg, timeout=None): return rv[-1] -def cast(context, topic, msg): +def cast(conf, context, topic, msg): try: - call(context, topic, msg) + call(conf, context, topic, msg) except Exception: pass -def notify(context, topic, msg): +def notify(conf, context, topic, msg): check_serialize(msg) @@ -169,7 +166,7 @@ def cleanup(): pass -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Cast to all consumers of a topic""" check_serialize(msg) method = msg.get('method') @@ -182,3 +179,7 @@ def fanout_cast(context, topic, msg): consumer.call(context, method, args, None) except Exception: pass + + +def register_opts(conf): + pass diff --git a/nova/rpc/impl_kombu.py b/nova/rpc/impl_kombu.py index 676aec57240e..6ff87646ca56 100644 --- a/nova/rpc/impl_kombu.py +++ b/nova/rpc/impl_kombu.py @@ -28,7 +28,6 @@ import kombu.entity import kombu.messaging import kombu.connection -from nova import flags from nova.openstack.common import cfg from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common @@ -49,8 +48,6 @@ kombu_opts = [ '(valid only if SSL enabled)')), ] -FLAGS = flags.FLAGS -FLAGS.register_opts(kombu_opts) LOG = rpc_common.LOG @@ -126,7 +123,7 @@ class ConsumerBase(object): class DirectConsumer(ConsumerBase): """Queue/consumer class for 'direct'""" - def __init__(self, channel, msg_id, callback, tag, **kwargs): + def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): """Init a 'direct' queue. 'channel' is the amqp channel to use @@ -159,7 +156,7 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'""" - def __init__(self, channel, topic, callback, tag, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, **kwargs): """Init a 'topic' queue. 'channel' is the amqp channel to use @@ -170,12 +167,12 @@ class TopicConsumer(ConsumerBase): Other kombu options may be passed """ # Default options - options = {'durable': FLAGS.rabbit_durable_queues, + options = {'durable': conf.rabbit_durable_queues, 'auto_delete': False, 'exclusive': False} options.update(kwargs) exchange = kombu.entity.Exchange( - name=FLAGS.control_exchange, + name=conf.control_exchange, type='topic', durable=options['durable'], auto_delete=options['auto_delete']) @@ -192,7 +189,7 @@ class TopicConsumer(ConsumerBase): class FanoutConsumer(ConsumerBase): """Consumer class for 'fanout'""" - def __init__(self, channel, topic, callback, tag, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, **kwargs): """Init a 'fanout' queue. 'channel' is the amqp channel to use @@ -252,7 +249,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'""" - def __init__(self, channel, msg_id, **kwargs): + def __init__(self, conf, channel, msg_id, **kwargs): """init a 'direct' publisher. Kombu options may be passed as keyword args to override defaults @@ -271,17 +268,17 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'""" - def __init__(self, channel, topic, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """init a 'topic' publisher. Kombu options may be passed as keyword args to override defaults """ - options = {'durable': FLAGS.rabbit_durable_queues, + options = {'durable': conf.rabbit_durable_queues, 'auto_delete': False, 'exclusive': False} options.update(kwargs) super(TopicPublisher, self).__init__(channel, - FLAGS.control_exchange, + conf.control_exchange, topic, type='topic', **options) @@ -289,7 +286,7 @@ class TopicPublisher(Publisher): class FanoutPublisher(Publisher): """Publisher class for 'fanout'""" - def __init__(self, channel, topic, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """init a 'fanout' publisher. Kombu options may be passed as keyword args to override defaults @@ -308,9 +305,9 @@ class FanoutPublisher(Publisher): class NotifyPublisher(TopicPublisher): """Publisher class for 'notify'""" - def __init__(self, *args, **kwargs): - self.durable = kwargs.pop('durable', FLAGS.rabbit_durable_queues) - super(NotifyPublisher, self).__init__(*args, **kwargs) + def __init__(self, conf, channel, topic, **kwargs): + self.durable = kwargs.pop('durable', conf.rabbit_durable_queues) + super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) def reconnect(self, channel): super(NotifyPublisher, self).reconnect(channel) @@ -329,15 +326,18 @@ class NotifyPublisher(TopicPublisher): class Connection(object): """Connection object.""" - def __init__(self, server_params=None): + pool = None + + def __init__(self, conf, server_params=None): self.consumers = [] self.consumer_thread = None - self.max_retries = FLAGS.rabbit_max_retries + self.conf = conf + self.max_retries = self.conf.rabbit_max_retries # Try forever? if self.max_retries <= 0: self.max_retries = None - self.interval_start = FLAGS.rabbit_retry_interval - self.interval_stepping = FLAGS.rabbit_retry_backoff + self.interval_start = self.conf.rabbit_retry_interval + self.interval_stepping = self.conf.rabbit_retry_backoff # max retry-interval = 30 seconds self.interval_max = 30 self.memory_transport = False @@ -353,21 +353,21 @@ class Connection(object): p_key = server_params_to_kombu_params.get(sp_key, sp_key) params[p_key] = value - params.setdefault('hostname', FLAGS.rabbit_host) - params.setdefault('port', FLAGS.rabbit_port) - params.setdefault('userid', FLAGS.rabbit_userid) - params.setdefault('password', FLAGS.rabbit_password) - params.setdefault('virtual_host', FLAGS.rabbit_virtual_host) + params.setdefault('hostname', self.conf.rabbit_host) + params.setdefault('port', self.conf.rabbit_port) + params.setdefault('userid', self.conf.rabbit_userid) + params.setdefault('password', self.conf.rabbit_password) + params.setdefault('virtual_host', self.conf.rabbit_virtual_host) self.params = params - if FLAGS.fake_rabbit: + if self.conf.fake_rabbit: self.params['transport'] = 'memory' self.memory_transport = True else: self.memory_transport = False - if FLAGS.rabbit_use_ssl: + if self.conf.rabbit_use_ssl: self.params['ssl'] = self._fetch_ssl_params() self.connection = None @@ -379,14 +379,14 @@ class Connection(object): ssl_params = dict() # http://docs.python.org/library/ssl.html - ssl.wrap_socket - if FLAGS.kombu_ssl_version: - ssl_params['ssl_version'] = FLAGS.kombu_ssl_version - if FLAGS.kombu_ssl_keyfile: - ssl_params['keyfile'] = FLAGS.kombu_ssl_keyfile - if FLAGS.kombu_ssl_certfile: - ssl_params['certfile'] = FLAGS.kombu_ssl_certfile - if FLAGS.kombu_ssl_ca_certs: - ssl_params['ca_certs'] = FLAGS.kombu_ssl_ca_certs + if self.conf.kombu_ssl_version: + ssl_params['ssl_version'] = self.conf.kombu_ssl_version + if self.conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile + if self.conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.conf.kombu_ssl_certfile + if self.conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs # We might want to allow variations in the # future with this? ssl_params['cert_reqs'] = ssl.CERT_REQUIRED @@ -534,7 +534,7 @@ class Connection(object): "%(err_str)s") % log_info) def _declare_consumer(): - consumer = consumer_cls(self.channel, topic, callback, + consumer = consumer_cls(self.conf, self.channel, topic, callback, self.consumer_num.next()) self.consumers.append(consumer) return consumer @@ -590,7 +590,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publish(): - publisher = cls(self.channel, topic, **kwargs) + publisher = cls(self.conf, self.channel, topic, **kwargs) publisher.send(msg) self.ensure(_error_callback, _publish) @@ -648,58 +648,66 @@ class Connection(object): def create_consumer(self, topic, proxy, fanout=False): """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self, Connection)) + if fanout: - self.declare_fanout_consumer(topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + self.declare_fanout_consumer(topic, proxy_cb) else: - self.declare_topic_consumer(topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + self.declare_topic_consumer(topic, proxy_cb) -Connection.pool = rpc_amqp.Pool(connection_cls=Connection) - - -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" - return rpc_amqp.create_connection(new, Connection.pool) + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" - return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast(context, topic, msg): +def cast(conf, context, topic, msg): """Sends a message on a topic without waiting for a response.""" - return rpc_amqp.cast(context, topic, msg, Connection.pool) + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" - return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast_to_server(context, server_params, topic, msg): +def cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a topic to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast_to_server(context, server_params, topic, msg): +def fanout_cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a fanout exchange to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def notify(context, topic, msg): +def notify(conf, context, topic, msg): """Sends a notification event on a topic.""" - return rpc_amqp.notify(context, topic, msg, Connection.pool) + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) def cleanup(): return rpc_amqp.cleanup(Connection.pool) + + +def register_opts(conf): + conf.register_opts(kombu_opts) diff --git a/nova/rpc/impl_qpid.py b/nova/rpc/impl_qpid.py index c275246b0fd6..37bb62430264 100644 --- a/nova/rpc/impl_qpid.py +++ b/nova/rpc/impl_qpid.py @@ -25,7 +25,6 @@ import greenlet import qpid.messaging import qpid.messaging.exceptions -from nova import flags from nova import log as logging from nova.openstack.common import cfg from nova.rpc import amqp as rpc_amqp @@ -78,9 +77,6 @@ qpid_opts = [ help='Disable Nagle algorithm'), ] -FLAGS = flags.FLAGS -FLAGS.register_opts(qpid_opts) - class ConsumerBase(object): """Consumer base class.""" @@ -147,7 +143,7 @@ class ConsumerBase(object): class DirectConsumer(ConsumerBase): """Queue/consumer class for 'direct'""" - def __init__(self, session, msg_id, callback): + def __init__(self, conf, session, msg_id, callback): """Init a 'direct' queue. 'session' is the amqp session to use @@ -165,7 +161,7 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'""" - def __init__(self, session, topic, callback): + def __init__(self, conf, session, topic, callback): """Init a 'topic' queue. 'session' is the amqp session to use @@ -174,14 +170,14 @@ class TopicConsumer(ConsumerBase): """ super(TopicConsumer, self).__init__(session, callback, - "%s/%s" % (FLAGS.control_exchange, topic), {}, + "%s/%s" % (conf.control_exchange, topic), {}, topic, {}) class FanoutConsumer(ConsumerBase): """Consumer class for 'fanout'""" - def __init__(self, session, topic, callback): + def __init__(self, conf, session, topic, callback): """Init a 'fanout' queue. 'session' is the amqp session to use @@ -236,7 +232,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'""" - def __init__(self, session, msg_id): + def __init__(self, conf, session, msg_id): """Init a 'direct' publisher.""" super(DirectPublisher, self).__init__(session, msg_id, {"type": "Direct"}) @@ -244,16 +240,16 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'topic' publisher. """ super(TopicPublisher, self).__init__(session, - "%s/%s" % (FLAGS.control_exchange, topic)) + "%s/%s" % (conf.control_exchange, topic)) class FanoutPublisher(Publisher): """Publisher class for 'fanout'""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'fanout' publisher. """ super(FanoutPublisher, self).__init__(session, @@ -262,29 +258,32 @@ class FanoutPublisher(Publisher): class NotifyPublisher(Publisher): """Publisher class for notifications""" - def __init__(self, session, topic): + def __init__(self, conf, session, topic): """init a 'topic' publisher. """ super(NotifyPublisher, self).__init__(session, - "%s/%s" % (FLAGS.control_exchange, topic), + "%s/%s" % (conf.control_exchange, topic), {"durable": True}) class Connection(object): """Connection object.""" - def __init__(self, server_params=None): + pool = None + + def __init__(self, conf, server_params=None): self.session = None self.consumers = {} self.consumer_thread = None + self.conf = conf if server_params is None: server_params = {} - default_params = dict(hostname=FLAGS.qpid_hostname, - port=FLAGS.qpid_port, - username=FLAGS.qpid_username, - password=FLAGS.qpid_password) + default_params = dict(hostname=self.conf.qpid_hostname, + port=self.conf.qpid_port, + username=self.conf.qpid_username, + password=self.conf.qpid_password) params = server_params for key in default_params.keys(): @@ -298,23 +297,25 @@ class Connection(object): # before we call open self.connection.username = params['username'] self.connection.password = params['password'] - self.connection.sasl_mechanisms = FLAGS.qpid_sasl_mechanisms - self.connection.reconnect = FLAGS.qpid_reconnect - if FLAGS.qpid_reconnect_timeout: - self.connection.reconnect_timeout = FLAGS.qpid_reconnect_timeout - if FLAGS.qpid_reconnect_limit: - self.connection.reconnect_limit = FLAGS.qpid_reconnect_limit - if FLAGS.qpid_reconnect_interval_max: + self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms + self.connection.reconnect = self.conf.qpid_reconnect + if self.conf.qpid_reconnect_timeout: + self.connection.reconnect_timeout = ( + self.conf.qpid_reconnect_timeout) + if self.conf.qpid_reconnect_limit: + self.connection.reconnect_limit = self.conf.qpid_reconnect_limit + if self.conf.qpid_reconnect_interval_max: self.connection.reconnect_interval_max = ( - FLAGS.qpid_reconnect_interval_max) - if FLAGS.qpid_reconnect_interval_min: + self.conf.qpid_reconnect_interval_max) + if self.conf.qpid_reconnect_interval_min: self.connection.reconnect_interval_min = ( - FLAGS.qpid_reconnect_interval_min) - if FLAGS.qpid_reconnect_interval: - self.connection.reconnect_interval = FLAGS.qpid_reconnect_interval - self.connection.hearbeat = FLAGS.qpid_heartbeat - self.connection.protocol = FLAGS.qpid_protocol - self.connection.tcp_nodelay = FLAGS.qpid_tcp_nodelay + self.conf.qpid_reconnect_interval_min) + if self.conf.qpid_reconnect_interval: + self.connection.reconnect_interval = ( + self.conf.qpid_reconnect_interval) + self.connection.hearbeat = self.conf.qpid_heartbeat + self.connection.protocol = self.conf.qpid_protocol + self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay # Open is part of reconnect - # NOTE(WGH) not sure we need this with the reconnect flags @@ -339,7 +340,7 @@ class Connection(object): self.connection.open() except qpid.messaging.exceptions.ConnectionError, e: LOG.error(_('Unable to connect to AMQP server: %s'), e) - time.sleep(FLAGS.qpid_reconnect_interval or 1) + time.sleep(self.conf.qpid_reconnect_interval or 1) else: break @@ -386,7 +387,7 @@ class Connection(object): "%(err_str)s") % log_info) def _declare_consumer(): - consumer = consumer_cls(self.session, topic, callback) + consumer = consumer_cls(self.conf, self.session, topic, callback) self._register_consumer(consumer) return consumer @@ -435,7 +436,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publisher_send(): - publisher = cls(self.session, topic) + publisher = cls(self.conf, self.session, topic) publisher.send(msg) return self.ensure(_connect_error, _publisher_send) @@ -493,60 +494,70 @@ class Connection(object): def create_consumer(self, topic, proxy, fanout=False): """Create a consumer that calls a method in a proxy object""" + proxy_cb = rpc_amqp.ProxyCallback(self.conf, proxy, + rpc_amqp.get_connection_pool(self, Connection)) + if fanout: - consumer = FanoutConsumer(self.session, topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb) else: - consumer = TopicConsumer(self.session, topic, - rpc_amqp.ProxyCallback(proxy, Connection.pool)) + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb) + self._register_consumer(consumer) + return consumer -Connection.pool = rpc_amqp.Pool(connection_cls=Connection) - - -def create_connection(new=True): +def create_connection(conf, new=True): """Create a connection""" - return rpc_amqp.create_connection(new, Connection.pool) + return rpc_amqp.create_connection(conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) -def multicall(context, topic, msg, timeout=None): +def multicall(conf, context, topic, msg, timeout=None): """Make a call that returns multiple times.""" - return rpc_amqp.multicall(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.multicall(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def call(context, topic, msg, timeout=None): +def call(conf, context, topic, msg, timeout=None): """Sends a message on a topic and wait for a response.""" - return rpc_amqp.call(context, topic, msg, timeout, Connection.pool) + return rpc_amqp.call(conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast(context, topic, msg): +def cast(conf, context, topic, msg): """Sends a message on a topic without waiting for a response.""" - return rpc_amqp.cast(context, topic, msg, Connection.pool) + return rpc_amqp.cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast(context, topic, msg): +def fanout_cast(conf, context, topic, msg): """Sends a message on a fanout exchange without waiting for a response.""" - return rpc_amqp.fanout_cast(context, topic, msg, Connection.pool) + return rpc_amqp.fanout_cast(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def cast_to_server(context, server_params, topic, msg): +def cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a topic to a specific server.""" - return rpc_amqp.cast_to_server(context, server_params, topic, msg, - Connection.pool) + return rpc_amqp.cast_to_server(conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) -def fanout_cast_to_server(context, server_params, topic, msg): +def fanout_cast_to_server(conf, context, server_params, topic, msg): """Sends a message on a fanout exchange to a specific server.""" - return rpc_amqp.fanout_cast_to_server(context, server_params, topic, - msg, Connection.pool) + return rpc_amqp.fanout_cast_to_server(conf, context, server_params, topic, + msg, rpc_amqp.get_connection_pool(conf, Connection)) -def notify(context, topic, msg): +def notify(conf, context, topic, msg): """Sends a notification event on a topic.""" - return rpc_amqp.notify(context, topic, msg, Connection.pool) + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) def cleanup(): return rpc_amqp.cleanup(Connection.pool) + + +def register_opts(conf): + conf.register_opts(qpid_opts) diff --git a/nova/service.py b/nova/service.py index a351406fe940..c9817bbe8333 100644 --- a/nova/service.py +++ b/nova/service.py @@ -177,6 +177,7 @@ class Service(object): LOG.audit(_('Starting %(topic)s node (version %(vcs_string)s)'), {'topic': self.topic, 'vcs_string': vcs_string}) utils.cleanup_file_locks() + rpc.register_opts(FLAGS) self.manager.init_host() self.model_disconnected = False ctxt = context.get_admin_context() @@ -393,6 +394,7 @@ class WSGIService(object): """ utils.cleanup_file_locks() + rpc.register_opts(FLAGS) if self.manager: self.manager.init_host() self.server.start() diff --git a/nova/tests/__init__.py b/nova/tests/__init__.py index fee29da6cc80..0e33cd7ace66 100644 --- a/nova/tests/__init__.py +++ b/nova/tests/__init__.py @@ -59,6 +59,9 @@ def reset_db(): def setup(): import mox # Fail fast if you don't have mox. Workaround for bug 810424 + from nova import rpc # Register rpc_backend before fake_flags sets it + FLAGS.register_opts(rpc.rpc_opts) + from nova import context from nova import db from nova.db import migration diff --git a/nova/tests/rpc/common.py b/nova/tests/rpc/common.py index 3524e5682800..d04f0561f868 100644 --- a/nova/tests/rpc/common.py +++ b/nova/tests/rpc/common.py @@ -26,19 +26,21 @@ import nose from nova import context from nova import exception +from nova import flags from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common from nova import test +FLAGS = flags.FLAGS LOG = logging.getLogger(__name__) class BaseRpcTestCase(test.TestCase): def setUp(self, supports_timeouts=True): super(BaseRpcTestCase, self).setUp() - self.conn = self.rpc.create_connection(True) + self.conn = self.rpc.create_connection(FLAGS, True) self.receiver = TestReceiver() self.conn.create_consumer('test', self.receiver, False) self.conn.consume_in_thread() @@ -51,20 +53,20 @@ class BaseRpcTestCase(test.TestCase): def test_call_succeed(self): value = 42 - result = self.rpc.call(self.context, 'test', {"method": "echo", - "args": {"value": value}}) + result = self.rpc.call(FLAGS, self.context, 'test', + {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) def test_call_succeed_despite_multiple_returns_yield(self): value = 42 - result = self.rpc.call(self.context, 'test', + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "echo_three_times_yield", "args": {"value": value}}) self.assertEqual(value + 2, result) def test_multicall_succeed_once(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "echo", "args": {"value": value}}) @@ -75,7 +77,7 @@ class BaseRpcTestCase(test.TestCase): def test_multicall_three_nones(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "multicall_three_nones", "args": {"value": value}}) @@ -86,7 +88,7 @@ class BaseRpcTestCase(test.TestCase): def test_multicall_succeed_three_times_yield(self): value = 42 - result = self.rpc.multicall(self.context, + result = self.rpc.multicall(FLAGS, self.context, 'test', {"method": "echo_three_times_yield", "args": {"value": value}}) @@ -96,7 +98,7 @@ class BaseRpcTestCase(test.TestCase): def test_context_passed(self): """Makes sure a context is passed through rpc call.""" value = 42 - result = self.rpc.call(self.context, + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "context", "args": {"value": value}}) self.assertEqual(self.context.to_dict(), result) @@ -112,7 +114,7 @@ class BaseRpcTestCase(test.TestCase): # TODO(comstud): # so, it will replay the context and use the same REQID? # that's bizarre. - ret = self.rpc.call(context, + ret = self.rpc.call(FLAGS, context, queue, {"method": "echo", "args": {"value": value}}) @@ -120,11 +122,11 @@ class BaseRpcTestCase(test.TestCase): return value nested = Nested() - conn = self.rpc.create_connection(True) + conn = self.rpc.create_connection(FLAGS, True) conn.create_consumer('nested', nested, False) conn.consume_in_thread() value = 42 - result = self.rpc.call(self.context, + result = self.rpc.call(FLAGS, self.context, 'nested', {"method": "echo", "args": {"queue": "test", "value": value}}) @@ -139,12 +141,12 @@ class BaseRpcTestCase(test.TestCase): value = 42 self.assertRaises(rpc_common.Timeout, self.rpc.call, - self.context, + FLAGS, self.context, 'test', {"method": "block", "args": {"value": value}}, timeout=1) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "block", "args": {"value": value}}, @@ -169,8 +171,8 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase): self.stubs.Set(rpc_amqp, 'unpack_context', fake_unpack_context) value = 41 - self.rpc.cast(self.context, 'test', {"method": "echo", - "args": {"value": value}}) + self.rpc.cast(FLAGS, self.context, 'test', + {"method": "echo", "args": {"value": value}}) # Wait for the cast to complete. for x in xrange(50): @@ -185,7 +187,7 @@ class BaseRpcAMQPTestCase(BaseRpcTestCase): self.stubs.Set(rpc_amqp, 'unpack_context', orig_unpack) value = 42 - result = self.rpc.call(self.context, 'test', + result = self.rpc.call(FLAGS, self.context, 'test', {"method": "echo", "args": {"value": value}}) self.assertEqual(value, result) diff --git a/nova/tests/rpc/test_common.py b/nova/tests/rpc/test_common.py index 6220bd01a134..4b505db97484 100644 --- a/nova/tests/rpc/test_common.py +++ b/nova/tests/rpc/test_common.py @@ -93,7 +93,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, exception.NovaException)) self.assertTrue('test message' in unicode(after_exc)) #assure the traceback was added @@ -108,7 +108,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) def test_deserialize_remote_exception_user_defined_exception(self): @@ -121,7 +121,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) #assure the traceback was added self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) @@ -141,7 +141,7 @@ class RpcCommonTestCase(test.TestCase): } serialized = json.dumps(failure) - after_exc = rpc_common.deserialize_remote_exception(serialized) + after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) #assure the traceback was added self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) diff --git a/nova/tests/rpc/test_kombu.py b/nova/tests/rpc/test_kombu.py index 966cb3a6905b..a66857567f72 100644 --- a/nova/tests/rpc/test_kombu.py +++ b/nova/tests/rpc/test_kombu.py @@ -53,6 +53,7 @@ def _raise_exc_stub(stubs, times, obj, method, exc_msg, class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def setUp(self): self.rpc = impl_kombu + impl_kombu.register_opts(FLAGS) super(RpcKombuTestCase, self).setUp() def tearDown(self): @@ -61,10 +62,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_reusing_connection(self): """Test that reusing a connection returns same one.""" - conn_context = self.rpc.create_connection(new=False) + conn_context = self.rpc.create_connection(FLAGS, new=False) conn1 = conn_context.connection conn_context.close() - conn_context = self.rpc.create_connection(new=False) + conn_context = self.rpc.create_connection(FLAGS, new=False) conn2 = conn_context.connection conn_context.close() self.assertEqual(conn1, conn2) @@ -72,7 +73,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_topic_send_receive(self): """Test sending to a topic exchange/queue""" - conn = self.rpc.create_connection() + conn = self.rpc.create_connection(FLAGS) message = 'topic test message' self.received_message = None @@ -89,7 +90,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def test_direct_send_receive(self): """Test sending to a direct exchange/queue""" - conn = self.rpc.create_connection() + conn = self.rpc.create_connection(FLAGS) message = 'direct test message' self.received_message = None @@ -123,10 +124,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def topic_send(_context, topic, msg): pass - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection) - impl_kombu.cast(ctxt, 'fake_topic', {'msg': 'fake'}) + impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'}) def test_cast_to_server_uses_server_params(self): """Test kombu rpc.cast""" @@ -153,10 +154,10 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): def topic_send(_context, topic, msg): pass - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_kombu, 'Connection', MyConnection) - impl_kombu.cast_to_server(ctxt, server_params, + impl_kombu.cast_to_server(FLAGS, ctxt, server_params, 'fake_topic', {'msg': 'fake'}) @test.skip_test("kombu memory transport seems buggy with fanout queues " @@ -192,7 +193,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, '__init__', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) result = conn.declare_consumer(self.rpc.DirectConsumer, 'test_topic', None) @@ -206,7 +207,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectConsumer, '__init__', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) result = conn.declare_consumer(self.rpc.DirectConsumer, @@ -220,7 +221,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer, '__init__', 'Socket closed', exc_class=IOError) - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) result = conn.declare_consumer(self.rpc.DirectConsumer, 'test_topic', None) @@ -234,7 +235,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, '__init__', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') self.assertEqual(info['called'], 3) @@ -243,7 +244,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectPublisher, 'send', 'foo timeout foo') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') self.assertEqual(info['called'], 3) @@ -256,7 +257,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, '__init__', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') @@ -267,7 +268,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): info = _raise_exc_stub(self.stubs, 1, self.rpc.DirectPublisher, 'send', 'meow') - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) conn.connection_errors = (MyException, ) conn.publisher_send(self.rpc.DirectPublisher, 'test_topic', 'msg') @@ -275,7 +276,7 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): self.assertEqual(info['called'], 2) def test_iterconsume_errors_will_reconnect(self): - conn = self.rpc.Connection() + conn = self.rpc.Connection(FLAGS) message = 'reconnect test message' self.received_message = None @@ -305,12 +306,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): value = "This is the exception message" self.assertRaises(NotImplementedError, self.rpc.call, + FLAGS, self.context, 'test', {"method": "fail", "args": {"value": value}}) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "fail", "args": {"value": value}}) @@ -330,12 +332,13 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): value = "This is the exception message" self.assertRaises(exception.ConvertedException, self.rpc.call, + FLAGS, self.context, 'test', {"method": "fail_converted", "args": {"value": value}}) try: - self.rpc.call(self.context, + self.rpc.call(FLAGS, self.context, 'test', {"method": "fail_converted", "args": {"value": value}}) diff --git a/nova/tests/rpc/test_kombu_ssl.py b/nova/tests/rpc/test_kombu_ssl.py index 2a10835cc8e3..fb5a32eb88b8 100644 --- a/nova/tests/rpc/test_kombu_ssl.py +++ b/nova/tests/rpc/test_kombu_ssl.py @@ -19,6 +19,7 @@ Unit Tests for remote procedure calls using kombu + ssl """ +from nova import flags from nova import test from nova.rpc import impl_kombu @@ -28,11 +29,14 @@ SSL_CERT = "/tmp/cert.blah.blah" SSL_CA_CERT = "/tmp/cert.ca.blah.blah" SSL_KEYFILE = "/tmp/keyfile.blah.blah" +FLAGS = flags.FLAGS + class RpcKombuSslTestCase(test.TestCase): def setUp(self): super(RpcKombuSslTestCase, self).setUp() + impl_kombu.register_opts(FLAGS) self.flags(kombu_ssl_keyfile=SSL_KEYFILE, kombu_ssl_ca_certs=SSL_CA_CERT, kombu_ssl_certfile=SSL_CERT, @@ -41,7 +45,7 @@ class RpcKombuSslTestCase(test.TestCase): def test_ssl_on_extended(self): rpc = impl_kombu - conn = rpc.create_connection(True) + conn = rpc.create_connection(FLAGS, True) c = conn.connection #This might be kombu version dependent... #Since we are now peaking into the internals of kombu... diff --git a/nova/tests/rpc/test_qpid.py b/nova/tests/rpc/test_qpid.py index 616abb1c90e0..7959f3783ba0 100644 --- a/nova/tests/rpc/test_qpid.py +++ b/nova/tests/rpc/test_qpid.py @@ -23,6 +23,7 @@ Unit Tests for remote procedure calls using qpid import mox from nova import context +from nova import flags from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova import test @@ -35,6 +36,7 @@ except ImportError: impl_qpid = None +FLAGS = flags.FLAGS LOG = logging.getLogger(__name__) @@ -64,6 +66,7 @@ class RpcQpidTestCase(test.TestCase): self.mock_receiver = None if qpid: + impl_qpid.register_opts(FLAGS) self.orig_connection = qpid.messaging.Connection self.orig_session = qpid.messaging.Session self.orig_sender = qpid.messaging.Sender @@ -98,7 +101,7 @@ class RpcQpidTestCase(test.TestCase): self.mox.ReplayAll() - connection = impl_qpid.create_connection() + connection = impl_qpid.create_connection(FLAGS) connection.close() def _test_create_consumer(self, fanout): @@ -130,7 +133,7 @@ class RpcQpidTestCase(test.TestCase): self.mox.ReplayAll() - connection = impl_qpid.create_connection() + connection = impl_qpid.create_connection(FLAGS) connection.create_consumer("impl_qpid_test", lambda *_x, **_y: None, fanout) @@ -176,11 +179,11 @@ class RpcQpidTestCase(test.TestCase): try: ctx = context.RequestContext("user", "project") - args = [ctx, "impl_qpid_test", + args = [FLAGS, ctx, "impl_qpid_test", {"method": "test_method", "args": {}}] if server_params: - args.insert(1, server_params) + args.insert(2, server_params) if fanout: method = impl_qpid.fanout_cast_to_server else: @@ -218,7 +221,7 @@ class RpcQpidTestCase(test.TestCase): server_params['hostname'] + ':' + str(server_params['port'])) - MyConnection.pool = rpc_amqp.Pool(connection_cls=MyConnection) + MyConnection.pool = rpc_amqp.Pool(FLAGS, MyConnection) self.stubs.Set(impl_qpid, 'Connection', MyConnection) @test.skip_if(qpid is None, "Test requires qpid") @@ -295,7 +298,7 @@ class RpcQpidTestCase(test.TestCase): else: method = impl_qpid.call - res = method(ctx, "impl_qpid_test", + res = method(FLAGS, ctx, "impl_qpid_test", {"method": "test_method", "args": {}}) if multi: