diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py index df94efa..3c1cd5f 100644 --- a/migrate/versioning/schema.py +++ b/migrate/versioning/schema.py @@ -1,6 +1,7 @@ from sqlalchemy import Table,Column,MetaData,String,Integer,create_engine from sqlalchemy import exceptions as sa_exceptions from migrate.versioning.repository import Repository +from migrate.versioning.util import loadModel from migrate.versioning.version import VerNum from migrate.versioning import exceptions, genmodel, schemadiff @@ -98,12 +99,7 @@ class ControlledSchema(object): if isinstance(repository, basestring): repository=Repository(repository) - if isinstance(model, basestring): # TODO: centralize this code? - # Assume model is of form "mod1.mod2.varname". - varname = model.split('.')[-1] - modules = '.'.join(model.split('.')[:-1]) - module = __import__(modules, globals(), {}, ['dummy-not-used'], -1) - model = getattr(module, varname) + model = loadModel(model) diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) return diff @@ -122,12 +118,7 @@ class ControlledSchema(object): if isinstance(repository, basestring): repository=Repository(repository) - if isinstance(model, basestring): # TODO: centralize this code? - # Assume model is of form "mod1.mod2.varname". - varname = model.split('.')[-1] - modules = '.'.join(model.split('.')[:-1]) - module = __import__(modules, globals(), {}, ['dummy-not-used'], -1) - model = getattr(module, varname) + model = loadModel(model) diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) return genmodel.ModelGenerator(diff).applyModel() diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index 6103827..7add4f4 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -4,7 +4,7 @@ from migrate.versioning import exceptions, genmodel, schemadiff from migrate.versioning.base import operations from migrate.versioning.template import template from migrate.versioning.script import base -from migrate.versioning.util import import_path +from migrate.versioning.util import import_path, loadModel class PythonScript(base.BaseScript): @classmethod @@ -28,12 +28,7 @@ class PythonScript(base.BaseScript): if isinstance(repository, basestring): from migrate.versioning.repository import Repository # oh dear, an import cycle! repository=Repository(repository) - if isinstance(model, basestring): # TODO: centralize this code? - # Assume model is of form "mod1.mod2.varname". - varname = model.split('.')[-1] - modules = '.'.join(model.split('.')[:-1]) - module = __import__(modules, globals(), {}, ['dummy-not-used'], -1) - model = getattr(module, varname) + model = loadModel(model) diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) upgradeDecls, upgradeCommands = genmodel.ModelGenerator(diff).toUpgradePython() #downgradeCommands = genmodel.ModelGenerator(diff).toDowngradePython() diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py index 9ad00a7..c1404ac 100644 --- a/migrate/versioning/util/__init__.py +++ b/migrate/versioning/util/__init__.py @@ -1,3 +1,14 @@ from keyedinstance import KeyedInstance from importpath import import_path +def loadModel(model): + ''' Import module and use module-level variable -- assume model is of form "mod1.mod2.varname". ''' + if isinstance(model, basestring): + varname = model.split('.')[-1] + modules = '.'.join(model.split('.')[:-1]) + module = __import__(modules, globals(), {}, ['dummy-not-used'], -1) + return getattr(module, varname) + else: + # Assume it's already loaded. + return model +