author     takeshi_hoshina <takeshi_hoshina@mail.toyota.co.jp>  2020-11-02 11:07:33 +0900
committer  takeshi_hoshina <takeshi_hoshina@mail.toyota.co.jp>  2020-11-02 11:07:33 +0900
commit     1c7d6584a7811b7785ae5c1e378f14b5ba0971cf (patch)
tree       cd70a267a5ef105ba32f200aa088e281fbd85747 /external/poky/bitbake/lib/hashserv
parent     4204309872da5cb401cbb2729d9e2d4869a87f42 (diff)
ref        basesystem-jjsandbox/ToshikazuOhiwa/master-jj

recipes
Diffstat (limited to 'external/poky/bitbake/lib/hashserv')
-rw-r--r--  external/poky/bitbake/lib/hashserv/__init__.py  115
-rw-r--r--  external/poky/bitbake/lib/hashserv/client.py    191
-rw-r--r--  external/poky/bitbake/lib/hashserv/server.py    489
-rw-r--r--  external/poky/bitbake/lib/hashserv/tests.py     165
4 files changed, 960 insertions, 0 deletions
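The four files below add BitBake's hash equivalence service: a small JSON-over-socket server, backed by SQLite, that maps a (method, taskhash) pair to a unified hash (unihash). For orientation, a minimal end-to-end sketch of the API this commit adds, mirroring what tests.py does in setUp; the socket path, database file, and hash values are illustrative placeholders, not part of the commit:

    import multiprocessing
    import hashserv

    # Start a server on a Unix socket in a child process (serve_forever blocks).
    server = hashserv.create_server('unix://./hashserv.sock', './hashes.db')
    p = multiprocessing.Process(target=server.serve_forever)
    p.start()

    # Ask for a unihash; report one if the task has not been seen yet.
    client = hashserv.create_client(server.address)
    taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'   # placeholder
    outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
    unihash = client.get_unihash('TestMethod', taskhash)    # None if unknown
    if unihash is None:
        unihash = client.report_unihash(taskhash, 'TestMethod', outhash, taskhash)['unihash']

    client.close()
    p.terminate()
    p.join()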
diff --git a/external/poky/bitbake/lib/hashserv/__init__.py b/external/poky/bitbake/lib/hashserv/__init__.py
new file mode 100644
index 00000000..f95e8f43
--- /dev/null
+++ b/external/poky/bitbake/lib/hashserv/__init__.py
@@ -0,0 +1,115 @@
+# Copyright (C) 2018-2019 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+from contextlib import closing
+import re
+import sqlite3
+import itertools
+import json
+
+UNIX_PREFIX = "unix://"
+
+ADDR_TYPE_UNIX = 0
+ADDR_TYPE_TCP = 1
+
+# The Python async server defaults to a 64K receive buffer, so we hardcode our
+# maximum chunk size. It would be better if the client and server reported to
+# each other what the maximum chunk sizes were, but that will slow down the
+# connection setup with a round trip delay so I'd rather not do that unless it
+# is necessary
+DEFAULT_MAX_CHUNK = 32 * 1024
+
+def setup_database(database, sync=True):
+    db = sqlite3.connect(database)
+    db.row_factory = sqlite3.Row
+
+    with closing(db.cursor()) as cursor:
+        cursor.execute('''
+            CREATE TABLE IF NOT EXISTS tasks_v2 (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                method TEXT NOT NULL,
+                outhash TEXT NOT NULL,
+                taskhash TEXT NOT NULL,
+                unihash TEXT NOT NULL,
+                created DATETIME,
+
+                -- Optional fields
+                owner TEXT,
+                PN TEXT,
+                PV TEXT,
+                PR TEXT,
+                task TEXT,
+                outhash_siginfo TEXT,
+
+                UNIQUE(method, outhash, taskhash)
+                )
+            ''')
+        cursor.execute('PRAGMA journal_mode = WAL')
+        cursor.execute('PRAGMA synchronous = %s' % ('NORMAL' if sync else 'OFF'))
+
+        # Drop old indexes
+        cursor.execute('DROP INDEX IF EXISTS taskhash_lookup')
+        cursor.execute('DROP INDEX IF EXISTS outhash_lookup')
+
+        # Create new indexes
+        cursor.execute('CREATE INDEX IF NOT EXISTS taskhash_lookup_v2 ON tasks_v2 (method, taskhash, created)')
+        cursor.execute('CREATE INDEX IF NOT EXISTS outhash_lookup_v2 ON tasks_v2 (method, outhash)')
+
+    return db
+
+
+def parse_address(addr):
+    if addr.startswith(UNIX_PREFIX):
+        return (ADDR_TYPE_UNIX, (addr[len(UNIX_PREFIX):],))
+    else:
+        m = re.match(r'\[(?P<host>[^\]]*)\]:(?P<port>\d+)$', addr)
+        if m is not None:
+            host = m.group('host')
+            port = m.group('port')
+        else:
+            host, port = addr.split(':')
+
+        return (ADDR_TYPE_TCP, (host, int(port)))
+
+
+def chunkify(msg, max_chunk):
+    if len(msg) < max_chunk - 1:
+        yield ''.join((msg, "\n"))
+    else:
+        yield ''.join((json.dumps({
+                'chunk-stream': None
+            }), "\n"))
+
+        args = [iter(msg)] * (max_chunk - 1)
+        for m in map(''.join, itertools.zip_longest(*args, fillvalue='')):
+            yield ''.join(itertools.chain(m, "\n"))
+        yield "\n"
+
+
+def create_server(addr, dbname, *, sync=True):
+    from . import server
+    db = setup_database(dbname, sync=sync)
+    s = server.Server(db)
+
+    (typ, a) = parse_address(addr)
+    if typ == ADDR_TYPE_UNIX:
+        s.start_unix_server(*a)
+    else:
+        s.start_tcp_server(*a)
+
+    return s
+
+
+def create_client(addr):
+    from . import client
+    c = client.Client()
+
+    (typ, a) = parse_address(addr)
+    if typ == ADDR_TYPE_UNIX:
+        c.connect_unix(*a)
+    else:
+        c.connect_tcp(*a)
+
+    return c
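Two helpers above carry the rest of the module: parse_address splits an address string into an address family plus connect arguments, and chunkify splits a JSON message that exceeds the chunk limit into newline-terminated slices announced by a chunk-stream marker. A quick sketch of their observable behaviour, with the expected values worked out by hand from the code above:

    from hashserv import parse_address, chunkify, ADDR_TYPE_UNIX, ADDR_TYPE_TCP

    assert parse_address('unix:///tmp/sock') == (ADDR_TYPE_UNIX, ('/tmp/sock',))
    assert parse_address('localhost:8686') == (ADDR_TYPE_TCP, ('localhost', 8686))
    assert parse_address('[::1]:8686') == (ADDR_TYPE_TCP, ('::1', 8686))  # IPv6 uses [host]:port

    # A short message fits in one chunk and is simply newline-terminated...
    assert list(chunkify('{"get-stats": null}', 1024)) == ['{"get-stats": null}\n']

    # ...while a long one is announced with a chunk-stream marker, sliced into
    # (max_chunk - 1)-character lines, and terminated by an empty line.
    assert list(chunkify('{"get-stats": null}', 8)) == [
        '{"chunk-stream": null}\n',
        '{"get-s\n', 'tats": \n', 'null}\n',
        '\n',
    ]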
diff --git a/external/poky/bitbake/lib/hashserv/client.py b/external/poky/bitbake/lib/hashserv/client.py
new file mode 100644
index 00000000..a29af836
--- /dev/null
+++ b/external/poky/bitbake/lib/hashserv/client.py
@@ -0,0 +1,191 @@
+# Copyright (C) 2019 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+import json
+import logging
+import socket
+import os
+from . import chunkify, DEFAULT_MAX_CHUNK
+
+
+logger = logging.getLogger('hashserv.client')
+
+
+class HashConnectionError(Exception):
+    pass
+
+
+class Client(object):
+    MODE_NORMAL = 0
+    MODE_GET_STREAM = 1
+
+    def __init__(self):
+        self._socket = None
+        self.reader = None
+        self.writer = None
+        self.mode = self.MODE_NORMAL
+        self.max_chunk = DEFAULT_MAX_CHUNK
+
+    def connect_tcp(self, address, port):
+        def connect_sock():
+            s = socket.create_connection((address, port))
+
+            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+            return s
+
+        self._connect_sock = connect_sock
+
+    def connect_unix(self, path):
+        def connect_sock():
+            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+            # AF_UNIX has path length issues so chdir here to workaround
+            cwd = os.getcwd()
+            try:
+                os.chdir(os.path.dirname(path))
+                s.connect(os.path.basename(path))
+            finally:
+                os.chdir(cwd)
+            return s
+
+        self._connect_sock = connect_sock
+
+    def connect(self):
+        if self._socket is None:
+            self._socket = self._connect_sock()
+
+            self.reader = self._socket.makefile('r', encoding='utf-8')
+            self.writer = self._socket.makefile('w', encoding='utf-8')
+
+            self.writer.write('OEHASHEQUIV 1.1\n\n')
+            self.writer.flush()
+
+            # Restore mode if the socket is being re-created
+            cur_mode = self.mode
+            self.mode = self.MODE_NORMAL
+            self._set_mode(cur_mode)
+
+        return self._socket
+
+    def close(self):
+        if self._socket is not None:
+            self._socket.close()
+            self._socket = None
+            self.reader = None
+            self.writer = None
+
+    def _send_wrapper(self, proc):
+        count = 0
+        while True:
+            try:
+                self.connect()
+                return proc()
+            except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
+                logger.warning('Error talking to server: %s' % e)
+                if count >= 3:
+                    if not isinstance(e, HashConnectionError):
+                        raise HashConnectionError(str(e))
+                    raise e
+                self.close()
+                count += 1
+
+    def send_message(self, msg):
+        def get_line():
+            line = self.reader.readline()
+            if not line:
+                raise HashConnectionError('Connection closed')
+
+            if not line.endswith('\n'):
+                raise HashConnectionError('Bad message %r' % line)
+
+            return line
+
+        def proc():
+            for c in chunkify(json.dumps(msg), self.max_chunk):
+                self.writer.write(c)
+            self.writer.flush()
+
+            l = get_line()
+
+            m = json.loads(l)
+            if 'chunk-stream' in m:
+                lines = []
+                while True:
+                    l = get_line().rstrip('\n')
+                    if not l:
+                        break
+                    lines.append(l)
+
+                m = json.loads(''.join(lines))
+
+            return m
+
+        return self._send_wrapper(proc)
+
+    def send_stream(self, msg):
+        def proc():
+            self.writer.write("%s\n" % msg)
+            self.writer.flush()
+            l = self.reader.readline()
+            if not l:
+                raise HashConnectionError('Connection closed')
+            return l.rstrip()
+
+        return self._send_wrapper(proc)
+
+    def _set_mode(self, new_mode):
+        if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
+            r = self.send_stream('END')
+            if r != 'ok':
+                raise HashConnectionError('Bad response from server %r' % r)
+        elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
+            r = self.send_message({'get-stream': None})
+            if r != 'ok':
+                raise HashConnectionError('Bad response from server %r' % r)
+        elif new_mode != self.mode:
+            raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))
+
+        self.mode = new_mode
+
+    def get_unihash(self, method, taskhash):
+        self._set_mode(self.MODE_GET_STREAM)
+        r = self.send_stream('%s %s' % (method, taskhash))
+        if not r:
+            return None
+        return r
+
+    def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
+        self._set_mode(self.MODE_NORMAL)
+        m = extra.copy()
+        m['taskhash'] = taskhash
+        m['method'] = method
+        m['outhash'] = outhash
+        m['unihash'] = unihash
+        return self.send_message({'report': m})
+
+    def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
+        self._set_mode(self.MODE_NORMAL)
+        m = extra.copy()
+        m['taskhash'] = taskhash
+        m['method'] = method
+        m['unihash'] = unihash
+        return self.send_message({'report-equiv': m})
+
+    def get_taskhash(self, method, taskhash, all_properties=False):
+        self._set_mode(self.MODE_NORMAL)
+        return self.send_message({'get': {
+            'taskhash': taskhash,
+            'method': method,
+            'all': all_properties
+        }})
+
+    def get_stats(self):
+        self._set_mode(self.MODE_NORMAL)
+        return self.send_message({'get-stats': None})
+
+    def reset_stats(self):
+        self._set_mode(self.MODE_NORMAL)
+        return self.send_message({'reset-stats': None})
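The mode handling above is worth spelling out: get_unihash flips the connection into stream mode ('get-stream'), where each request is a bare 'method taskhash' line, and any normal-mode call first sends 'END' to leave the stream. A sketch of a session against a running server; the address and hash are placeholders, and note that TCP_QUICKACK in connect_tcp is Linux-specific:

    from hashserv.client import Client

    client = Client()
    client.connect_tcp('localhost', 8686)        # placeholder address

    # First get_unihash sends {"get-stream": null}, then 'method taskhash' lines.
    print(client.get_unihash('TestMethod', 'deadbeef' * 5))   # None if unknown

    # get_stats is a normal-mode call, so the client first sends 'END' to leave
    # stream mode, then the JSON message {"get-stats": null}.
    print(client.get_stats()['requests']['num'])

    client.close()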
diff --git a/external/poky/bitbake/lib/hashserv/server.py b/external/poky/bitbake/lib/hashserv/server.py
new file mode 100644
index 00000000..81050715
--- /dev/null
+++ b/external/poky/bitbake/lib/hashserv/server.py
@@ -0,0 +1,489 @@
+# Copyright (C) 2019 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+from contextlib import closing
+from datetime import datetime
+import asyncio
+import json
+import logging
+import math
+import os
+import signal
+import socket
+import time
+from . import chunkify, DEFAULT_MAX_CHUNK
+
+logger = logging.getLogger('hashserv.server')
+
+
+class Measurement(object):
+    def __init__(self, sample):
+        self.sample = sample
+
+    def start(self):
+        self.start_time = time.perf_counter()
+
+    def end(self):
+        self.sample.add(time.perf_counter() - self.start_time)
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.end()
+
+
+class Sample(object):
+    def __init__(self, stats):
+        self.stats = stats
+        self.num_samples = 0
+        self.elapsed = 0
+
+    def measure(self):
+        return Measurement(self)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.end()
+
+    def add(self, elapsed):
+        self.num_samples += 1
+        self.elapsed += elapsed
+
+    def end(self):
+        if self.num_samples:
+            self.stats.add(self.elapsed)
+            self.num_samples = 0
+            self.elapsed = 0
+
+
+class Stats(object):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.num = 0
+        self.total_time = 0
+        self.max_time = 0
+        self.m = 0
+        self.s = 0
+        self.current_elapsed = None
+
+    def add(self, elapsed):
+        self.num += 1
+        if self.num == 1:
+            self.m = elapsed
+            self.s = 0
+        else:
+            last_m = self.m
+            self.m = last_m + (elapsed - last_m) / self.num
+            self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
+
+        self.total_time += elapsed
+
+        if self.max_time < elapsed:
+            self.max_time = elapsed
+
+    def start_sample(self):
+        return Sample(self)
+
+    @property
+    def average(self):
+        if self.num == 0:
+            return 0
+        return self.total_time / self.num
+
+    @property
+    def stdev(self):
+        if self.num <= 1:
+            return 0
+        return math.sqrt(self.s / (self.num - 1))
+
+    def todict(self):
+        return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
+
+
+class ClientError(Exception):
+    pass
+
+class ServerClient(object):
+    FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+    ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
+
+    def __init__(self, reader, writer, db, request_stats):
+        self.reader = reader
+        self.writer = writer
+        self.db = db
+        self.request_stats = request_stats
+        self.max_chunk = DEFAULT_MAX_CHUNK
+
+        self.handlers = {
+            'get': self.handle_get,
+            'report': self.handle_report,
+            'report-equiv': self.handle_equivreport,
+            'get-stream': self.handle_get_stream,
+            'get-stats': self.handle_get_stats,
+            'reset-stats': self.handle_reset_stats,
+            'chunk-stream': self.handle_chunk,
+        }
+
+    async def process_requests(self):
+        try:
+            self.addr = self.writer.get_extra_info('peername')
+            logger.debug('Client %r connected' % (self.addr,))
+
+            # Read protocol and version
+            protocol = await self.reader.readline()
+            if protocol is None:
+                return
+
+            (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
+            if proto_name != 'OEHASHEQUIV':
+                return
+
+            proto_version = tuple(int(v) for v in proto_version.split('.'))
+            if proto_version < (1, 0) or proto_version > (1, 1):
+                return
+
+            # Read headers. Currently, no headers are implemented, so look for
+            # an empty line to signal the end of the headers
+            while True:
+                line = await self.reader.readline()
+                if line is None:
+                    return
+
+                line = line.decode('utf-8').rstrip()
+                if not line:
+                    break
+
+            # Handle messages
+            while True:
+                d = await self.read_message()
+                if d is None:
+                    break
+                await self.dispatch_message(d)
+                await self.writer.drain()
+        except ClientError as e:
+            logger.error(str(e))
+        finally:
+            self.writer.close()
+
+    async def dispatch_message(self, msg):
+        for k in self.handlers.keys():
+            if k in msg:
+                logger.debug('Handling %s' % k)
+                if 'stream' in k:
+                    await self.handlers[k](msg[k])
+                else:
+                    with self.request_stats.start_sample() as self.request_sample, \
+                            self.request_sample.measure():
+                        await self.handlers[k](msg[k])
+                return
+
+        raise ClientError("Unrecognized command %r" % msg)
+
+    def write_message(self, msg):
+        for c in chunkify(json.dumps(msg), self.max_chunk):
+            self.writer.write(c.encode('utf-8'))
+
+    async def read_message(self):
+        l = await self.reader.readline()
+        if not l:
+            return None
+
+        try:
+            message = l.decode('utf-8')
+
+            if not message.endswith('\n'):
+                return None
+
+            return json.loads(message)
+        except (json.JSONDecodeError, UnicodeDecodeError) as e:
+            logger.error('Bad message from client: %r' % l)
+            raise e
+
+    async def handle_chunk(self, request):
+        lines = []
+        try:
+            while True:
+                l = await self.reader.readline()
+                l = l.rstrip(b"\n").decode("utf-8")
+                if not l:
+                    break
+                lines.append(l)
+
+            msg = json.loads(''.join(lines))
+        except (json.JSONDecodeError, UnicodeDecodeError) as e:
+            logger.error('Bad message from client: %r' % lines)
+            raise e
+
+        if 'chunk-stream' in msg:
+            raise ClientError("Nested chunks are not allowed")
+
+        await self.dispatch_message(msg)
+
+    async def handle_get(self, request):
+        method = request['method']
+        taskhash = request['taskhash']
+
+        if request.get('all', False):
+            row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
+        else:
+            row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+
+        if row is not None:
+            logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
+            d = {k: row[k] for k in row.keys()}
+
+            self.write_message(d)
+        else:
+            self.write_message(None)
+
+    async def handle_get_stream(self, request):
+        self.write_message('ok')
+
+        while True:
+            l = await self.reader.readline()
+            if not l:
+                return
+
+            try:
+                # This inner loop is very sensitive and must be as fast as
+                # possible (which is why the request sample is handled manually
+                # instead of using 'with', and also why logging statements are
+                # commented out).
+                self.request_sample = self.request_stats.start_sample()
+                request_measure = self.request_sample.measure()
+                request_measure.start()
+
+                l = l.decode('utf-8').rstrip()
+                if l == 'END':
+                    self.writer.write('ok\n'.encode('utf-8'))
+                    return
+
+                (method, taskhash) = l.split()
+                #logger.debug('Looking up %s %s' % (method, taskhash))
+                row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
+                if row is not None:
+                    msg = ('%s\n' % row['unihash']).encode('utf-8')
+                    #logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
+                else:
+                    msg = '\n'.encode('utf-8')
+
+                self.writer.write(msg)
+            finally:
+                request_measure.end()
+                self.request_sample.end()
+
+            await self.writer.drain()
+
+    async def handle_report(self, data):
+        with closing(self.db.cursor()) as cursor:
+            cursor.execute('''
+                -- Find tasks with a matching outhash (that is, tasks that
+                -- are equivalent)
+                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
+
+                -- If there is an exact match on the taskhash, return it.
+                -- Otherwise return the oldest matching outhash of any
+                -- taskhash
+                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
+                    created ASC
+
+                -- Only return one row
+                LIMIT 1
+                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
+
+            row = cursor.fetchone()
+
+            # If no matching outhash was found, or one *was* found but it
+            # wasn't an exact match on the taskhash, a new entry for this
+            # taskhash should be added
+            if row is None or row['taskhash'] != data['taskhash']:
+                # If a row matching the outhash was found, the unihash for
+                # the new taskhash should be the same as that one.
+                # Otherwise the caller provided unihash is used.
+                unihash = data['unihash']
+                if row is not None:
+                    unihash = row['unihash']
+
+                insert_data = {
+                    'method': data['method'],
+                    'outhash': data['outhash'],
+                    'taskhash': data['taskhash'],
+                    'unihash': unihash,
+                    'created': datetime.now()
+                }
+
+                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
+                    if k in data:
+                        insert_data[k] = data[k]
+
+                cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
+                    ', '.join(sorted(insert_data.keys())),
+                    ', '.join(':' + k for k in sorted(insert_data.keys()))),
+                    insert_data)
+
+                self.db.commit()
+
+                logger.info('Adding taskhash %s with unihash %s',
+                            data['taskhash'], unihash)
+
+                d = {
+                    'taskhash': data['taskhash'],
+                    'method': data['method'],
+                    'unihash': unihash
+                }
+            else:
+                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+
+        self.write_message(d)
+
+    async def handle_equivreport(self, data):
+        with closing(self.db.cursor()) as cursor:
+            insert_data = {
+                'method': data['method'],
+                'outhash': "",
+                'taskhash': data['taskhash'],
+                'unihash': data['unihash'],
+                'created': datetime.now()
+            }
+
+            for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
+                if k in data:
+                    insert_data[k] = data[k]
+
+            cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % (
+                ', '.join(sorted(insert_data.keys())),
+                ', '.join(':' + k for k in sorted(insert_data.keys()))),
+                insert_data)
+
+            self.db.commit()
+
+            # Fetch the unihash that will be reported for the taskhash. If the
+            # unihash matches, it means this row was inserted (or the mapping
+            # was already valid)
+            row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
+
+            if row['unihash'] == data['unihash']:
+                logger.info('Adding taskhash equivalence for %s with unihash %s',
+                            data['taskhash'], row['unihash'])
+
+            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
+
+        self.write_message(d)
+
+    async def handle_get_stats(self, request):
+        d = {
+            'requests': self.request_stats.todict(),
+        }
+
+        self.write_message(d)
+
+    async def handle_reset_stats(self, request):
+        d = {
+            'requests': self.request_stats.todict(),
+        }
+
+        self.request_stats.reset()
+        self.write_message(d)
+
+    def query_equivalent(self, method, taskhash, query):
+        # This is part of the inner loop and must be as fast as possible
+        try:
+            cursor = self.db.cursor()
+            cursor.execute(query, {'method': method, 'taskhash': taskhash})
+            return cursor.fetchone()
+        except Exception:
+            cursor.close()
+
+
+class Server(object):
+    def __init__(self, db, loop=None):
+        self.request_stats = Stats()
+        self.db = db
+
+        if loop is None:
+            self.loop = asyncio.new_event_loop()
+            self.close_loop = True
+        else:
+            self.loop = loop
+            self.close_loop = False
+
+        self._cleanup_socket = None
+
+    def start_tcp_server(self, host, port):
+        self.server = self.loop.run_until_complete(
+            asyncio.start_server(self.handle_client, host, port, loop=self.loop)
+        )
+
+        for s in self.server.sockets:
+            logger.info('Listening on %r' % (s.getsockname(),))
+            # Newer python does this automatically. Do it manually here for
+            # maximum compatibility
+            s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+            s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
+
+        name = self.server.sockets[0].getsockname()
+        if self.server.sockets[0].family == socket.AF_INET6:
+            self.address = "[%s]:%d" % (name[0], name[1])
+        else:
+            self.address = "%s:%d" % (name[0], name[1])
+
+    def start_unix_server(self, path):
+        def cleanup():
+            os.unlink(path)
+
+        cwd = os.getcwd()
+        try:
+            # Work around path length limits in AF_UNIX
+            os.chdir(os.path.dirname(path))
+            self.server = self.loop.run_until_complete(
+                asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
+            )
+        finally:
+            os.chdir(cwd)
+
+        logger.info('Listening on %r' % path)
+
+        self._cleanup_socket = cleanup
+        self.address = "unix://%s" % os.path.abspath(path)
+
+    async def handle_client(self, reader, writer):
+        # writer.transport.set_write_buffer_limits(0)
+        try:
+            client = ServerClient(reader, writer, self.db, self.request_stats)
+            await client.process_requests()
+        except Exception as e:
+            import traceback
+            logger.error('Error from client: %s' % str(e), exc_info=True)
+            traceback.print_exc()
+            writer.close()
+
+        logger.info('Client disconnected')
+
+    def serve_forever(self):
+        def signal_handler():
+            self.loop.stop()
+
+        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+        try:
+            self.loop.run_forever()
+        except KeyboardInterrupt:
+            pass
+
+        self.server.close()
+        self.loop.run_until_complete(self.server.wait_closed())
+        logger.info('Server shutting down')
+
+        if self.close_loop:
+            self.loop.close()
+
+        if self._cleanup_socket is not None:
+            self._cleanup_socket()
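Stats.add above is Welford's online algorithm: m tracks the running mean and s the running sum of squared deviations, so the standard deviation can be derived without storing individual samples. A quick hand-verifiable check against the standard library:

    import statistics
    from hashserv.server import Stats

    stats = Stats()
    for elapsed in (1.0, 2.0, 3.0):
        stats.add(elapsed)

    assert stats.average == 2.0
    # Welford gives s = 2.0 after three samples, so stdev = sqrt(2/2) = 1.0
    assert abs(stats.stdev - statistics.stdev([1.0, 2.0, 3.0])) < 1e-9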
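The SELECT inside handle_report is the heart of the equivalence logic: among rows sharing an outhash, an exact taskhash match wins, otherwise the oldest row does, so the earliest unihash "sticks" to all equivalent tasks. A sketch against an in-memory database, with hash strings shortened for readability:

    from datetime import datetime
    from hashserv import setup_database

    # The same query handle_report uses, pulled out for demonstration.
    QUERY = '''SELECT taskhash, method, unihash FROM tasks_v2
               WHERE method=:method AND outhash=:outhash
               ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                        created ASC
               LIMIT 1'''

    db = setup_database(':memory:')
    for taskhash, unihash, day in (('task1', 'uni1', 1), ('task2', 'uni2', 2)):
        db.execute('INSERT INTO tasks_v2 (method, outhash, taskhash, unihash, created) '
                   'VALUES (?, ?, ?, ?, ?)',
                   ('m', 'out1', taskhash, unihash, datetime(2020, 1, day)))

    # An exact taskhash match beats an older row with the same outhash...
    row = db.execute(QUERY, {'method': 'm', 'outhash': 'out1', 'taskhash': 'task2'}).fetchone()
    assert row['unihash'] == 'uni2'

    # ...but an unseen taskhash inherits the unihash of the oldest equivalent row.
    row = db.execute(QUERY, {'method': 'm', 'outhash': 'out1', 'taskhash': 'task3'}).fetchone()
    assert row['unihash'] == 'uni1'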
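Since process_requests defines the wire format end to end (a version handshake, an empty header section, then one JSON document per line), the protocol can be poked at with nothing but a socket. A sketch, assuming a TCP server is listening on localhost:8686 (a placeholder address):

    import json
    import socket

    sock = socket.create_connection(('localhost', 8686))
    f = sock.makefile('rw', encoding='utf-8')

    f.write('OEHASHEQUIV 1.1\n\n')                 # handshake + empty headers
    f.write(json.dumps({'get-stats': None}) + '\n')
    f.flush()

    print(json.loads(f.readline()))                # e.g. {'requests': {'num': 0, ...}}
    sock.close()

Replies larger than DEFAULT_MAX_CHUNK arrive as a chunk-stream rather than a single line, which is why send_message on the client and handle_chunk on the server both understand the marker.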
diff --git a/external/poky/bitbake/lib/hashserv/tests.py b/external/poky/bitbake/lib/hashserv/tests.py
new file mode 100644
index 00000000..6e862950
--- /dev/null
+++ b/external/poky/bitbake/lib/hashserv/tests.py
@@ -0,0 +1,165 @@
+#! /usr/bin/env python3
+#
+# Copyright (C) 2018-2019 Garmin Ltd.
+#
+# SPDX-License-Identifier: GPL-2.0-only
+#
+
+from . import create_server, create_client
+import hashlib
+import logging
+import multiprocessing
+import os
+import sys
+import tempfile
+import threading
+import unittest
+
+
+class TestHashEquivalenceServer(object):
+    METHOD = 'TestMethod'
+
+    def _run_server(self):
+        # logging.basicConfig(level=logging.DEBUG, filename='bbhashserv.log', filemode='w',
+        #                     format='%(levelname)s %(filename)s:%(lineno)d %(message)s')
+        self.server.serve_forever()
+
+    def setUp(self):
+        if sys.version_info < (3, 5, 0):
+            self.skipTest('Python 3.5 or later required')
+
+        self.temp_dir = tempfile.TemporaryDirectory(prefix='bb-hashserv')
+        self.dbfile = os.path.join(self.temp_dir.name, 'db.sqlite')
+
+        self.server = create_server(self.get_server_addr(), self.dbfile)
+        self.server_thread = multiprocessing.Process(target=self._run_server)
+        self.server_thread.start()
+        self.client = create_client(self.server.address)
+
+    def tearDown(self):
+        # Shutdown server
+        s = getattr(self, 'server', None)
+        if s is not None:
+            self.server_thread.terminate()
+            self.server_thread.join()
+        self.client.close()
+        self.temp_dir.cleanup()
+
+    def test_create_hash(self):
+        # Simple test that hashes can be created
+        taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9'
+        outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f'
+        unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd'
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+    def test_create_equivalent(self):
+        # Tests that a second reported task with the same outhash will be
+        # assigned the same unihash
+        taskhash = '53b8dce672cb6d0c73170be43f540460bfc347b4'
+        outhash = '5a9cb1649625f0bf41fc7791b635cd9c2d7118c7f021ba87dcd03f72b67ce7a8'
+        unihash = 'f37918cc02eb5a520b1aff86faacbc0a38124646'
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+        # Report a different task with the same outhash. The returned unihash
+        # should match the first task
+        taskhash2 = '3bf6f1e89d26205aec90da04854fbdbf73afe6b4'
+        unihash2 = 'af36b199320e611fbb16f1f277d3ee1d619ca58b'
+        result = self.client.report_unihash(taskhash2, self.METHOD, outhash, unihash2)
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+    def test_duplicate_taskhash(self):
+        # Tests that duplicate reports of the same taskhash with different
+        # outhash & unihash always return the unihash from the first reported
+        # taskhash
+        taskhash = '8aa96fcffb5831b3c2c0cb75f0431e3f8b20554a'
+        outhash = 'afe240a439959ce86f5e322f8c208e1fedefea9e813f2140c81af866cc9edf7e'
+        unihash = '218e57509998197d570e2c98512d0105985dffc9'
+        self.client.report_unihash(taskhash, self.METHOD, outhash, unihash)
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertEqual(result, unihash)
+
+        outhash2 = '0904a7fe3dc712d9fd8a74a616ddca2a825a8ee97adf0bd3fc86082c7639914d'
+        unihash2 = 'ae9a7d252735f0dafcdb10e2e02561ca3a47314c'
+        self.client.report_unihash(taskhash, self.METHOD, outhash2, unihash2)
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertEqual(result, unihash)
+
+        outhash3 = '77623a549b5b1a31e3732dfa8fe61d7ce5d44b3370f253c5360e136b852967b4'
+        unihash3 = '9217a7d6398518e5dc002ed58f2cbbbc78696603'
+        self.client.report_unihash(taskhash, self.METHOD, outhash3, unihash3)
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertEqual(result, unihash)
+
+    def test_huge_message(self):
+        # Simple test that hashes can be created
+        taskhash = 'c665584ee6817aa99edfc77a44dd853828279370'
+        outhash = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44'
+        unihash = '90e9bc1d1f094c51824adca7f8ea79a048d68824'
+
+        result = self.client.get_unihash(self.METHOD, taskhash)
+        self.assertIsNone(result, msg='Found unexpected task, %r' % result)
+
+        siginfo = "0" * (self.client.max_chunk * 4)
+
+        result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash, {
+            'outhash_siginfo': siginfo
+        })
+        self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash')
+
+        result = self.client.get_taskhash(self.METHOD, taskhash, True)
+        self.assertEqual(result['taskhash'], taskhash)
+        self.assertEqual(result['unihash'], unihash)
+        self.assertEqual(result['method'], self.METHOD)
+        self.assertEqual(result['outhash'], outhash)
+        self.assertEqual(result['outhash_siginfo'], siginfo)
+
+    def test_stress(self):
+        def query_server(failures):
+            client = create_client(self.server.address)
+            try:
+                for i in range(1000):
+                    taskhash = hashlib.sha256()
+                    taskhash.update(str(i).encode('utf-8'))
+                    taskhash = taskhash.hexdigest()
+                    result = client.get_unihash(self.METHOD, taskhash)
+                    if result != taskhash:
+                        failures.append("taskhash mismatch: %s != %s" % (result, taskhash))
+            finally:
+                client.close()
+
+        # Report hashes
+        for i in range(1000):
+            taskhash = hashlib.sha256()
+            taskhash.update(str(i).encode('utf-8'))
+            taskhash = taskhash.hexdigest()
+            self.client.report_unihash(taskhash, self.METHOD, taskhash, taskhash)
+
+        failures = []
+        threads = [threading.Thread(target=query_server, args=(failures,)) for t in range(100)]
+
+        for t in threads:
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        self.assertFalse(failures)
+
+
+class TestHashEquivalenceUnixServer(TestHashEquivalenceServer, unittest.TestCase):
+    def get_server_addr(self):
+        return "unix://" + os.path.join(self.temp_dir.name, 'sock')
+
+
+class TestHashEquivalenceTCPServer(TestHashEquivalenceServer, unittest.TestCase):
+    def get_server_addr(self):
+        return "localhost:0"
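The suite parameterises one base class over both transports; the TCP variant binds "localhost:0" so the OS picks an ephemeral port, which start_tcp_server then reads back into server.address. A sketch of running the suite directly with unittest, assuming the working directory is external/poky/bitbake/lib so the hashserv package is importable:

    # Equivalent to: cd external/poky/bitbake/lib && python3 -m unittest -v hashserv.tests
    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName('hashserv.tests')
    unittest.TextTestRunner(verbosity=2).run(suite)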