diff --git a/cfxdb/exporter.py b/cfxdb/exporter.py index 5cc53db8..b1c6ea9d 100644 --- a/cfxdb/exporter.py +++ b/cfxdb/exporter.py @@ -216,13 +216,25 @@ def export_database(self, filename, include_indexes=False, include_schemata=None click.style('{}.{}'.format(schema_name, table_name), fg='white', bold=True), click.style(str(cnt) + ' records', fg='yellow'))) - def import_database(self, filename, include_indexes=True): + def import_database(self, filename, include_indexes=False, include_schemata=None, exclude_tables=None): """ :param filename: :param include_indexes: :return: """ + if include_schemata is None: + schemata = sorted(self._schemata.keys()) + else: + assert type(include_schemata) == list + schemata = sorted(list(set(include_schemata).intersection(self._schemata.keys()))) + + if exclude_tables is None: + exclude_tables = set() + else: + assert type(exclude_tables) == list + exclude_tables = set(exclude_tables) + with open(filename, 'rb') as f: data = f.read() db_data = cbor2.loads(data) @@ -232,21 +244,24 @@ def import_database(self, filename, include_indexes=True): dbpath=self._dbpath, filename=filename, filesize=len(data))) with self._db.begin(write=True) as txn: - for schema_name in self._schemata: + for schema_name in schemata: for table_name in self._schema_tables[schema_name]: - table = self._schemata[schema_name].__dict__[table_name] - if not table.is_index() or include_indexes: - if schema_name in db_data and table_name in db_data[schema_name]: - cnt = 0 - for key, val in db_data[schema_name][table_name]: - key = table._deserialize_key(key) - val = table.parse(val) - table[txn, key] = val - cnt += 1 - if cnt: - print('{:.<52}: {}'.format( - click.style('{}.{}'.format(schema_name, table_name), - fg='white', - bold=True), click.style(str(cnt) + ' records', fg='yellow'))) - else: - print('No data to import for {}.{}!'.format(schema_name, table_name)) + fq_table_name = '{}.{}'.format(schema_name, table_name) + if fq_table_name not in exclude_tables: + table = self._schemata[schema_name].__dict__[table_name] + if not table.is_index() or include_indexes: + if schema_name in db_data and table_name in db_data[schema_name]: + cnt = 0 + for key, val in db_data[schema_name][table_name]: + key = table._deserialize_key(key) + val = table.parse(val) + table[txn, key] = val + cnt += 1 + if cnt: + print('{:.<52}: {}'.format( + click.style('{}.{}'.format(schema_name, table_name), + fg='white', + bold=True), click.style(str(cnt) + ' records', + fg='yellow'))) + else: + print('No data to import for {}.{}!'.format(schema_name, table_name))