'''Utility functions'''
# Copyright 2013 Christopher Foo <chris.foo@gmail.com>
# Licensed under GPLv3. See COPYING.txt for details.
import collections
import datetime
import email.utils
import hashlib
import http.client
import io
import logging
import os
import tempfile
import threading
import urllib.parse

_logger = logging.getLogger(__name__)


def printable_str_to_str(s):
    r'''Convert printable escapes (``\r``, ``\n``, ``\t``) back to control characters.'''
    return s.translate(str.maketrans('', '', '\t\r\n'))\
        .replace(r'\r', '\r')\
        .replace(r'\n', '\n')\
        .replace(r'\t', '\t')
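
# A minimal usage sketch (illustrative): a line rendered with visible
# escapes is converted back to one with real CRLF characters.
#
#   printable_str_to_str(r'GET / HTTP/1.1\r\n')  # -> 'GET / HTTP/1.1\r\n' (real CRLF)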


def find_file_pattern(file_obj, pattern, bufsize=512, limit=4096,
        inclusive=False):
    '''Find the offset of ``pattern`` from the current position.

    Raises :exc:`ValueError` if the pattern is not found within
    ``limit`` bytes.
    '''
    original_position = file_obj.tell()
    bytes_read = 0

    # FIXME: don't accumulate growing buffer
    search_buf = io.BytesIO()

    while True:
        if limit:
            size = min(bufsize, limit - bytes_read)
        else:
            size = bufsize

        data = file_obj.read(size)

        if not data:
            break

        search_buf.write(data)

        try:
            index = search_buf.getvalue().index(pattern)
        except ValueError:
            pass
        else:
            offset = index

            if inclusive:
                offset += len(pattern)

            file_obj.seek(original_position)
            return offset

        bytes_read += len(data)

    file_obj.seek(original_position)
    raise ValueError('Search for pattern exhausted')
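
# Usage sketch: locate the end of the HTTP header block in a stream.
#
#   stream = io.BytesIO(b'HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\nbody')
#   find_file_pattern(stream, b'\r\n\r\n', inclusive=True)  # -> 38, offset of b'body'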


def strip_warc_extension(s):
    '''Remove ``.warc`` or ``.warc.gz`` from the filename.'''
    if s.endswith('.gz'):
        s = s[:-3]

    if s.endswith('.warc'):
        s = s[:-5]

    return s
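
# Usage sketch:
#
#   strip_warc_extension('crawl-001.warc.gz')  # -> 'crawl-001'
#   strip_warc_extension('archive.warc')       # -> 'archive'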


class DiskBufferedReader(io.BufferedIOBase):
    '''Buffers the file to disk, large parts at a time.'''
    # Some segments lifted from _pyio.py
    # Copyright 2001-2011 Python Software Foundation
    # Licensed under Python Software Foundation License Version 2
    def __init__(self, raw, disk_buffer_size=104857600, spool_size=10485760):
        io.BufferedIOBase.__init__(self)
        self._raw = raw
        self._disk_buffer_size = disk_buffer_size
        self._offset = 0
        self._block_index = None
        self._block_file = None
        self._spool_size = spool_size
        self._lock = threading.RLock()
        self._cache = FileCache()

        self._set_block(0)

    def _set_block(self, index):
        if index == self._block_index:
            return

        with self._lock:
            self._block_index = index
            self._block_file = self._cache.get(self._block_index)

            if self._block_file:
                _logger.debug('Buffer block file cache hit. index=%d',
                    self._block_index)
            else:
                _logger.debug('Creating buffer block file. index=%d',
                    self._block_index)
                self._block_file = tempfile.SpooledTemporaryFile(
                    max_size=self._spool_size)
                self._raw.seek(self._block_index * self._disk_buffer_size)
                copyfile_obj(self._raw, self._block_file,
                    max_length=self._disk_buffer_size)
                self._cache.put(self._block_index, self._block_file)

                _logger.debug('Buffer block file created. length=%d',
                    self._block_file.tell())

            self._block_file.seek(0)

    def tell(self):
        return self._offset

    def seek(self, pos, whence=0):
        if not (0 <= whence <= 1):
            raise ValueError('Bad whence argument')

        with self._lock:
            if whence == 1:
                self._offset += pos
            else:
                self._offset = pos

            index = self._offset // self._disk_buffer_size
            self._set_block(index)
            self._block_file.seek(self._offset % self._disk_buffer_size)

    def read(self, n=None):
        buf = io.BytesIO()
        bytes_left = n

        with self._lock:
            while True:
                self.seek(self._offset)
                data = self._block_file.read(bytes_left)
                self._offset += len(data)
                buf.write(data)

                if not data:
                    break

                # n=None means read to EOF, so skip the size bookkeeping
                if bytes_left is not None:
                    bytes_left -= len(data)

                    if bytes_left <= 0:
                        break

        return buf.getvalue()

    def peek(self, n=0):
        with self._lock:
            original_position = self.tell()
            data = self.read(n)
            self.seek(original_position)

        return data

    def seekable(self):
        return self.raw.seekable()

    def readable(self):
        return self.raw.readable()

    def writable(self):
        return False

    @property
    def raw(self):
        return self._raw

    @property
    def closed(self):
        return self.raw.closed

    @property
    def name(self):
        return self.raw.name

    @property
    def mode(self):
        return self.raw.mode

    def fileno(self):
        return self.raw.fileno()

    def isatty(self):
        return self.raw.isatty()
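
# Usage sketch (illustrative): wrap a large WARC file so seeks and reads go
# through on-disk block buffers instead of hitting the raw stream directly.
#
#   with open('big.warc', 'rb') as raw:        # hypothetical file
#       reader = DiskBufferedReader(raw)
#       reader.seek(1024)
#       chunk = reader.read(512)
#       header = reader.peek(4)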


class FileCache(object):
    '''A cache containing references to file objects.

    File objects are closed when expired. The class is thread safe and
    only returns file objects belonging to the calling thread.
    '''
    def __init__(self, size=4):
        self._size = size
        self._files = collections.deque()
        self._lock = threading.Lock()

    def get(self, filename):
        thread_id = threading.current_thread()

        with self._lock:
            return self._get(filename, thread_id)

    def _get(self, filename, thread_id):
        for cache_filename, cache_thread_id, file_obj in self._files:
            if filename == cache_filename and thread_id == cache_thread_id:
                return file_obj

    def put(self, filename, file_obj):
        thread_id = threading.current_thread()

        with self._lock:
            if self._get(filename, thread_id):
                return

            # Evict and close the oldest entry once the cache is full
            if len(self._files) >= self._size:
                old_file_obj = self._files.popleft()[2]
                old_file_obj.close()

            self._files.append((filename, thread_id, file_obj))
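
# Usage sketch: entries are keyed by (filename, thread), so a file object
# put by one thread is invisible to the others.
#
#   cache = FileCache(size=2)
#   file_obj = io.BytesIO(b'data')      # stands in for a real file
#   cache.put('block-0', file_obj)
#   cache.get('block-0')                # -> file_obj (same thread only)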


def copyfile_obj(source, dest, bufsize=4096, max_length=None,
        write_attr_name='write'):
    '''Like :func:`shutil.copyfileobj` but with a limit on how much to copy.'''
    bytes_read = 0
    write_func = getattr(dest, write_attr_name)

    while True:
        if max_length is not None:
            read_size = min(bufsize, max_length - bytes_read)
        else:
            read_size = bufsize

        data = source.read(read_size)

        if not data:
            break

        write_func(data)
        bytes_read += len(data)
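
# Usage sketch: copy at most 10 bytes between in-memory streams.
#
#   src = io.BytesIO(b'0123456789abcdef')
#   dst = io.BytesIO()
#   copyfile_obj(src, dst, max_length=10)
#   dst.getvalue()  # -> b'0123456789'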


class HTTPSocketShim(io.BytesIO):
    '''Impersonate a socket by returning itself from :meth:`makefile`.'''
    def makefile(self, *args, **kwargs):
        return self


def parse_http_response(file_obj):
    '''Parse and return a :class:`http.client.HTTPResponse`.'''
    # BytesIO takes bytes, not a file object, so read the data in first
    response = http.client.HTTPResponse(HTTPSocketShim(file_obj.read()))
    response.begin()

    return response
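
# Usage sketch: parse a canned response held in a file-like object.
#
#   raw = io.BytesIO(b'HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello')
#   response = parse_http_response(raw)
#   response.status  # -> 200
#   response.read()  # -> b'hello'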


def split_url_to_filename(s):
    '''Attempt to split a URL into the parts of a filename on disk.'''
    url_info = urllib.parse.urlsplit(s)
    parts = [sanitize_str(url_info.netloc)]

    for part in url_info.path.lstrip('/').split('/'):
        part = sanitize_str(part)
        parts.append(part)

        # An empty part (e.g. from a trailing slash) becomes an index filename
        if not part:
            parts[-1] = append_index_filename(part)

    if url_info.query:
        parts[-1] += '_' + sanitize_str(url_info.query)

    if frozenset([os.curdir, os.pardir, '.', '..']) & frozenset(parts):
        raise ValueError('Path contains directory traversal filenames')

    return parts
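
# Usage sketch:
#
#   split_url_to_filename('http://example.com/dir/page?q=1')
#   # -> ['example.com', 'dir', 'page_q=1']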


# Characters unsafe in filenames: path separators, Windows-reserved
# punctuation, and the ASCII control characters
SANITIZE_BLACKLIST = frozenset(
    r'/\:*?"<>|' + ''.join([chr(i) for i in range(0, 32)]) + '\x7f'
)


def sanitize_str(s):
    '''Replace unsavory characters in the string with an underscore.'''
    return ''.join([c if c not in SANITIZE_BLACKLIST else '_' for c in s])
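
# Usage sketch:
#
#   sanitize_str('a/b:c*d')  # -> 'a_b_c_d'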


def append_index_filename(path):
    '''Add ``_index_xxxxxx`` to the path.

    The basename (the filename) of the path is used to generate the hex
    hash digest suffix.
    '''
    hasher = hashlib.sha1(os.path.basename(path).encode())
    path += '_index_{}'.format(hasher.hexdigest()[:6])

    return path
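
# Usage sketch: an empty basename (e.g. from a trailing slash) still gets a
# deterministic suffix, here derived from the SHA-1 of the empty string.
#
#   append_index_filename('example.com/dir/')  # -> 'example.com/dir/_index_da39a3'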


def rename_filename_dirs(dest_filename):
    '''Rename files that conflict with a directory in the given path.

    If a file has the same name as the directory, the file is renamed
    using :func:`append_index_filename`.
    '''
    path = dest_filename

    while True:
        path, filename = os.path.split(path)

        if not filename:
            break

        if os.path.isfile(path):
            new_path = append_index_filename(path)
            _logger.debug('Rename %s -> %s', path, new_path)
            os.rename(path, new_path)
            break
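
# Usage sketch (hypothetical paths): if a plain file 'site/blog' already
# exists, it is renamed out of the way so 'site/blog/' can become a directory.
#
#   rename_filename_dirs('site/blog/post.html')
#   # renames 'site/blog' -> 'site/blog_index_xxxxxx' if it is a file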


def truncate_filename_parts(path_parts, length=160):
    '''Truncate and suffix filename path parts that exceed the given length.

    If a part is too long, it is truncated to ``length`` characters and an
    underscore plus a six-character hex (``_xxxxxx``) suffix is appended.
    '''
    new_parts = list(path_parts)

    for index, part in enumerate(path_parts):
        if len(part) > length:
            hasher = hashlib.sha1(part.encode())
            new_part = '{}_{}'.format(part[:length], hasher.hexdigest()[:6])
            new_parts[index] = new_part

    return new_parts
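
# Usage sketch: only over-length parts are rewritten.
#
#   truncate_filename_parts(['example.com', 'a' * 200], length=160)
#   # -> ['example.com', 'aaa...a_xxxxxx'] (160 chars, then the hash suffix)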


def parse_http_date(s):
    '''Parse an RFC 2822 date string into a :class:`datetime.datetime`.'''
    t = email.utils.parsedate_tz(s)

    if not t:
        raise ValueError('Unable to parse date')

    # parsedate_tz gives None for the offset when no timezone is present
    tzinfo = datetime.timezone(datetime.timedelta(seconds=t[9] or 0))
    d = datetime.datetime(*t[:6], tzinfo=tzinfo)

    return d
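
# Usage sketch:
#
#   parse_http_date('Mon, 23 May 2005 22:38:34 GMT')
#   # -> datetime.datetime(2005, 5, 23, 22, 38, 34, tzinfo=datetime.timezone.utc)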


file_cache = FileCache()
'''The shared :class:`FileCache` instance'''