aboutsummaryrefslogtreecommitdiff
blob: ba75a529f9506e85f5bae360586de4f288d3a4d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Copyright: 2005 Gentoo Foundation
# Author(s): Brian Harring (ferringb@gentoo.org)
# License: GPL2

import sys
from portage.cache import template, cache_errors
from portage.cache.template import reconstruct_eclasses

class SQLDatabase(template.database):
	"""Template class for RDBMS based caches.

	This class is designed such that derivatives don't have to change much code, mostly constant strings.
	_BaseError must be an exception class (or tuple of classes) that all exceptions thrown from the
	derived RDBMS are derived from.

	SCHEMA_INSERT_CPV_INTO_PACKAGE should be modified dependent on the RDBMS, as should
	SCHEMA_PACKAGE_CREATE - basically you need to deal with creation of a unique pkgid.  If the dbapi2
	rdbms class has a method of recovering that id, then override _insert_cpv to remove the extra select.

	Creation of a derived class involves supplying _dbClass (or overriding _dbconnect) and
	_table_exists.  Additionally, the default schemas may have to be modified.
	"""

	SCHEMA_PACKAGE_NAME	= "package_cache"
	SCHEMA_PACKAGE_CREATE 	= "CREATE TABLE %s (\
		pkgid INTEGER PRIMARY KEY, label VARCHAR(255), cpv VARCHAR(255), UNIQUE(label, cpv))" % SCHEMA_PACKAGE_NAME
	SCHEMA_PACKAGE_DROP	= "DROP TABLE %s" % SCHEMA_PACKAGE_NAME

	SCHEMA_VALUES_NAME	= "values_cache"
	SCHEMA_VALUES_CREATE	= "CREATE TABLE %s ( pkgid integer references %s (pkgid) on delete cascade, \
		key varchar(255), value text, UNIQUE(pkgid, key))" % (SCHEMA_VALUES_NAME, SCHEMA_PACKAGE_NAME)
	SCHEMA_VALUES_DROP	= "DROP TABLE %s" % SCHEMA_VALUES_NAME
	SCHEMA_INSERT_CPV_INTO_PACKAGE	= "INSERT INTO %s (label, cpv) VALUES(%%s, %%s)" % SCHEMA_PACKAGE_NAME

	# exception class(es) raised by the backing DB-API module; the empty tuple
	# means "catch nothing" until a derived class fills it in.
	_BaseError = ()
	# callable returning a DB-API connection; invoked as _dbClass(**config)
	_dbClass = None

	autocommits = False
#	cleanse_keys = True

	# boolean indicating if the derived RDBMS class supports replace syntax
	_supports_replace = False

	def __init__(self, location, label, auxdbkeys, *args, **config):
		"""Initialize the instance and connect to the database.

		Derived classes shouldn't need to override this.  Any keyword arguments
		are handed to the base class and then, with "host" and "autocommit"
		defaults filled in, to _initdb_con.
		"""

		super(SQLDatabase, self).__init__(location, label, auxdbkeys, *args, **config)

		config.setdefault("host", "127.0.0.1")
		config.setdefault("autocommit", self.autocommits)
		self._initdb_con(config)

		# the label is only ever used inside SQL statements, so store it pre-quoted
		self.label = self._sfilter(self.label)


	def _dbconnect(self, config):
		"""Open the connection (self.db) and a cursor on it (self.con).

		Should be overridden if the derived class needs special parameters for
		initializing the db connection, or cursor.
		"""
		self.db = self._dbClass(**config)
		self.con = self.db.cursor()


	def _initdb_con(self, config):
		"""Connect and ensure the needed tables are in place.

		If the derived class needs a different set of table creation commands,
		overload the appropriate SCHEMA_ attributes.  If it needs additional
		execution beyond that, override this method.

		Raises cache_errors.ReadOnlyRestriction if a table is missing on a
		readonly cache, and cache_errors.InitializationError if creating a
		table fails.
		"""

		self._dbconnect(config)
		# the package table is created first: the values table references it
		needed_tables = (
			(self.SCHEMA_PACKAGE_NAME, self.SCHEMA_PACKAGE_CREATE),
			(self.SCHEMA_VALUES_NAME, self.SCHEMA_VALUES_CREATE),
		)
		for table_name, create_stmt in needed_tables:
			if self._table_exists(table_name):
				continue
			if self.readonly:
				raise cache_errors.ReadOnlyRestriction("table %s doesn't exist" % \
					table_name)
			try:
				self.con.execute(create_stmt)
			except self._BaseError as e:
				raise cache_errors.InitializationError(self.__class__, e)


	def _table_exists(self, tbl):
		"""Return true if the table named tbl exists.

		Derived classes must override this.
		"""
		raise NotImplementedError


	def _sfilter(self, s):
		"""Meta escaping: return s as a double-quoted string for use in sql statements.

		Backslashes and double quotes inside s are backslash-escaped.
		"""
		return "\"%s\"" % s.replace("\\","\\\\").replace("\"","\\\"")


	def _getitem(self, cpv):
		"""Return a dict of all known keys for cpv; keys without a stored value map to "".

		Raises KeyError if cpv has no rows, and cache_errors.CacheCorruption if
		the query itself fails.
		"""
		try:
			self.con.execute("SELECT key, value FROM %s NATURAL JOIN %s "
			"WHERE label=%s AND cpv=%s" % (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME,
			self.label, self._sfilter(cpv)))
		except self._BaseError as e:
			# (key, exception) form, the same one _setitem and _insert_cpv use
			raise cache_errors.CacheCorruption(cpv, e)

		rows = self.con.fetchall()

		if not rows:
			raise KeyError(cpv)

		vals = dict.fromkeys(self._known_keys, "")
		vals.update(dict(rows))
		return vals


	def _delitem(self, cpv):
		"""Delete a cpv cache entry.

		Derived RDBMS classes for this *must* either support cascaded deletes, or
		override this method.

		Raises KeyError if no row was deleted, and cache_errors.CacheCorruption
		if the statement fails.  On any failure a non-autocommitting cache is
		rolled back before the error propagates.
		"""
		try:
			try:
				self.con.execute("DELETE FROM %s WHERE label=%s AND cpv=%s" % \
				(self.SCHEMA_PACKAGE_NAME, self.label, self._sfilter(cpv)))
				if self.autocommits:
					self.commit()
			except self._BaseError as e:
				raise cache_errors.CacheCorruption(cpv, e)
			if self.con.rowcount <= 0:
				raise KeyError(cpv)
		except Exception:
			if not self.autocommits:
				# yes, this can roll back a lot more than just the delete.  deal.
				try:
					self.db.rollback()
				except self._BaseError:
					# don't let a failed rollback mask the original error
					pass
			raise

	def __del__(self):
		"""Commit and close the connection, if one was ever opened."""
		# just to be safe; "db" is missing if __init__ failed before connecting.
		if self.__dict__.get("db") is not None:
			self.commit()
			self.db.close()

	def _setitem(self, cpv, values):
		"""Store values for cpv, replacing any existing entry.

		Only known keys with a true value are written.  Raises
		cache_errors.CacheCorruption on database errors; on any failure a
		non-autocommitting cache is rolled back before the error propagates.
		"""

		try:
			# insert.
			try:
				pkgid = self._insert_cpv(cpv)
			except self._BaseError as e:
				raise cache_errors.CacheCorruption(cpv, e)

			# __getitem__ fills out missing values,
			# so we store only what's handed to us and is a known key
			db_values = [
				{"key": key, "value": values[key]}
				for key in self._known_keys
				if key in values and values[key]
			]

			if db_values:
				try:
					self.con.executemany("INSERT INTO %s (pkgid, key, value) VALUES(\"%s\", %%(key)s, %%(value)s)" % \
					(self.SCHEMA_VALUES_NAME, str(pkgid)), db_values)
				except self._BaseError as e:
					raise cache_errors.CacheCorruption(cpv, e)
			if self.autocommits:
				self.commit()

		except Exception:
			if not self.autocommits:
				try:
					self.db.rollback()
				except self._BaseError:
					pass
			raise


	def _insert_cpv(self, cpv):
		"""Insert cpv into the package table and return its new pkgid.

		Uses SCHEMA_INSERT_CPV_INTO_PACKAGE, which must be overloaded if the table
		definition doesn't support auto-increment columns for pkgid.  When the
		backend doesn't support REPLACE, any existing entry is deleted first.

		Note this doesn't commit the transaction.  The caller is expected to.

		Raises cache_errors.CacheCorruption if the follow-up select doesn't find
		exactly one row; database errors from the insert propagate after a rollback.
		"""

		escaped_cpv = self._sfilter(cpv)
		if self._supports_replace:
			query_str = self.SCHEMA_INSERT_CPV_INTO_PACKAGE.replace("INSERT","REPLACE",1)
		else:
			# just delete it.  Hand over the *raw* cpv: _delitem quotes its
			# argument itself, so an already-quoted value would never match.
			try:
				del self[cpv]
			except (cache_errors.CacheCorruption, KeyError):
				pass
			query_str = self.SCHEMA_INSERT_CPV_INTO_PACKAGE
		try:
			self.con.execute(query_str % (self.label, escaped_cpv))
		except self._BaseError:
			self.db.rollback()
			raise
		self.con.execute("SELECT pkgid FROM %s WHERE label=%s AND cpv=%s" % \
			(self.SCHEMA_PACKAGE_NAME, self.label, escaped_cpv))

		# count fetched rows rather than trusting cursor.rowcount, which the
		# DB-API leaves undefined (-1) for SELECT statements on some drivers.
		rows = self.con.fetchall()
		if len(rows) != 1:
			raise cache_errors.CacheCorruption(cpv, "Tried to insert the cpv, but found "
				" %i matches upon the following select!" % len(rows))
		return rows[0][0]


	def __contains__(self, cpv):
		"""Return True if cpv is stored under this cache's label.

		Pending changes are committed first on a non-autocommitting cache.
		Raises cache_errors.GeneralCacheCorruption on database errors.
		"""
		if not self.autocommits:
			try:
				self.commit()
			except self._BaseError as e:
				raise cache_errors.GeneralCacheCorruption(e)

		try:
			self.con.execute("SELECT cpv FROM %s WHERE label=%s AND cpv=%s" % \
				(self.SCHEMA_PACKAGE_NAME, self.label, self._sfilter(cpv)))
		except self._BaseError as e:
			raise cache_errors.GeneralCacheCorruption(e)
		# fetch instead of reading cursor.rowcount, which is not reliable after
		# a SELECT on every DB-API driver.
		return self.con.fetchone() is not None


	def __iter__(self):
		"""Yield every cpv stored under this cache's label.

		Pending changes are committed first on a non-autocommitting cache.
		Raises cache_errors.GeneralCacheCorruption on database errors.
		"""
		if not self.autocommits:
			try:
				self.commit()
			except self._BaseError as e:
				raise cache_errors.GeneralCacheCorruption(e)

		try:
			self.con.execute("SELECT cpv FROM %s WHERE label=%s" %
				(self.SCHEMA_PACKAGE_NAME, self.label))
		except self._BaseError as e:
			raise cache_errors.GeneralCacheCorruption(e)
		for row in self.con.fetchall():
			yield row[0]

	def iteritems(self):
		"""Yield (cpv, values) pairs for every entry stored under this cache's label.

		values is a dict of the stored key/value rows; its "_eclasses_" entry is
		passed through reconstruct_eclasses, or set to {} when absent.
		Raises cache_errors.GeneralCacheCorruption on database errors.
		"""
		try:
			self.con.execute("SELECT cpv, key, value FROM %s NATURAL JOIN %s "
			"WHERE label=%s" % (self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME,
			self.label))
		except self._BaseError as e:
			# no single cpv is involved here, so report it like __iter__ does
			raise cache_errors.GeneralCacheCorruption(e)

		# The query has no ORDER BY, so rows of one cpv are not guaranteed to be
		# adjacent; group explicitly, remembering first-seen order.
		grouped = {}
		seen_order = []
		for cpv, key, value in self.con.fetchall():
			if cpv not in grouped:
				grouped[cpv] = {}
				seen_order.append(cpv)
			grouped[cpv][key] = value

		for cpv in seen_order:
			d = grouped[cpv]
			if "_eclasses_" in d:
				d["_eclasses_"] = reconstruct_eclasses(cpv, d["_eclasses_"])
			else:
				d["_eclasses_"] = {}
			yield cpv, d

	def commit(self):
		"""Commit the current transaction on the underlying connection."""
		self.db.commit()

	def get_matches(self, match_dict):
		"""Return the cpvs whose stored values match every restriction in match_dict.

		match_dict maps known keys to patterns; ".*" in a pattern is translated
		to the SQL LIKE wildcard "%" and literal "%" characters are escaped.
		Raises cache_errors.InvalidRestriction for an unknown key and
		cache_errors.GeneralCacheCorruption on database errors.
		"""
		query_list = []
		for k, v in match_dict.items():
			if k not in self._known_keys:
				raise cache_errors.InvalidRestriction(k, v, "key isn't known to this cache instance")
			# escape literal % first so only the translated ".*" acts as a wildcard
			v = v.replace("%","\\%")
			v = v.replace(".*","%")
			query_list.append("(key=%s AND value LIKE %s)" % (self._sfilter(k), self._sfilter(v)))

		# NOTE(review): the clauses are ANDed against a single joined row, and a
		# row carries only one key, so restrictions on two different keys look
		# like they can never both match - confirm intended semantics.
		if query_list:
			query = " AND "+" AND ".join(query_list)
		else:
			query = ''

		try:
			self.con.execute("SELECT cpv FROM %s NATURAL JOIN %s WHERE label=%s %s" % \
				(self.SCHEMA_PACKAGE_NAME, self.SCHEMA_VALUES_NAME, self.label, query))
		except self._BaseError as e:
			raise cache_errors.GeneralCacheCorruption(e)

		return [ row[0] for row in self.con.fetchall() ]

	items = iteritems
	keys = __iter__