# User:Surjection/categorylister2.py
# Jump to navigation
# Jump to search
"""List members of MediaWiki categories and combine them with set operations.

Categories are given as ``site|categoryname`` (for example
``enwikt|English lemmas``).  Prefixing the site with ``@`` also walks the
subcategories.  ``-`` reads page titles from standard input instead.
"""

import argparse
import json
import operator
import sys
import urllib.parse
import urllib.request
from collections import Counter, OrderedDict
from functools import reduce

_DEBUG = False


class APIURL():
    """A MediaWiki ``api.php`` URL: a domain plus ordered query parameters."""

    def __init__(self, domain, **params):
        self.domain = domain
        self.params = OrderedDict(params)

    def make(self):
        """Return the full request URL with percent-encoded parameter values."""
        parstring = '&'.join(
            f'{key}={urllib.parse.quote(value)}'
            for key, value in self.params.items())
        return f'https://{self.domain}/w/api.php?{parstring}'

    def copy(self):
        """Return an independent APIURL with the same domain and parameters."""
        return APIURL(self.domain, **self.params)


class SiteNameParser():
    """Translate short site names such as ``enwikt`` into host names."""

    suffixes = {'wiki': '.wikipedia.org', 'wikt': '.wiktionary.org'}

    def parse(self, site):
        """Return the host name for *site*.

        Raises:
            ValueError: if *site* does not end in a known suffix, or consists
                of the suffix alone (which would give a host with an empty
                first label such as ``.wikipedia.org``).
        """
        for suffix, domain in SiteNameParser.suffixes.items():
            if site.endswith(suffix) and len(site) > len(suffix):
                return site[:-len(suffix)] + domain
        raise ValueError(f'unrecognized site name: {site}')


class CategoryNameParser():
    """Build ``list=categorymembers`` API URLs from category specifications."""

    def __init__(self):
        self.sitenameparser = SiteNameParser()

    @staticmethod
    def _members_url(domain, title):
        # Single place that defines the categorymembers query parameters.
        return APIURL(domain, format='json', action='query',
                      list='categorymembers', cmlimit='100', cmtitle=title)

    def parse(self, cat):
        """Parse ``[@]site|category`` into ``{'url': APIURL, 'deep': bool}``.

        Raises:
            ValueError: if *cat* contains no ``|`` or names an unknown site.
        """
        if '|' not in cat:
            raise ValueError('must be in format site|category')
        site, category = cat.split('|', 1)
        deep = site.startswith('@')
        if deep:
            site = site[1:]
        domain = self.sitenameparser.parse(site)
        return {'url': self._members_url(domain, 'Category:' + category),
                'deep': deep}

    def subcategory(self, url, category):
        """Return a deep-search entry for *category* on the same domain as *url*.

        *category* is used as the full page title, without adding a prefix.
        """
        return {'url': self._members_url(url.domain, category), 'deep': True}


class MWAPI():
    """Minimal MediaWiki API client."""

    def request(self, url):
        """Fetch *url* (an APIURL) and return the decoded JSON response."""
        address = url.make()
        if _DEBUG:
            print("Making API request to", address, file=sys.stderr)
        with urllib.request.urlopen(address) as req:
            result = json.loads(req.read().decode('utf-8'))
        return result

    def categorymembers(self, url):
        """Yield every category member entry for *url*, following continuation.

        The caller's *url* is left untouched: continuation parameters are
        written to a private copy, so the same APIURL can be reused later
        without resuming in the middle of a listing.
        """
        page_url = url.copy()
        while True:
            result = self.request(page_url)
            yield from result['query']['categorymembers']
            continuation = result.get('continue', {})
            if 'cmcontinue' not in continuation:
                break
            # Send back every continuation value, not only cmcontinue.
            page_url.params.update(
                (key, str(value)) for key, value in continuation.items())


class StdinLister():
    """Read page titles from standard input, one per line."""

    def collect(self):
        """Return all lines of standard input without their trailing newline."""
        return [line.rstrip('\n') for line in sys.stdin]


class CategoryLister():
    """Collect the page titles contained in a single category."""

    def __init__(self):
        self.catparser = CategoryNameParser()
        self.mwapi = MWAPI()

    def collect_sub(self, url, deep, include_cats=False, ns0=False,
                    _visited=None):
        """Return the titles of the members listed by *url*.

        Args:
            url: APIURL of a ``categorymembers`` query.
            deep: when true, recurse into members in namespace 14.
            include_cats: when true, namespace-14 titles are included in
                the result.
            ns0: when true, only namespace-0 titles are included.
            _visited: internal; category titles already crawled.  Each
                subcategory is crawled at most once, so category loops
                terminate and shared subcategories are not fetched again.
        """
        if _visited is None:
            _visited = {url.params.get('cmtitle')}
        pages = []
        for page in self.mwapi.categorymembers(url):
            namespace, title = page['ns'], page['title']
            if (namespace == 0 or not ns0) and (namespace != 14 or include_cats):
                pages.append(title)
            if deep and namespace == 14 and title not in _visited:
                _visited.add(title)
                suburl = self.catparser.subcategory(url, title)['url']
                pages += self.collect_sub(
                    suburl, True, include_cats, ns0, _visited)
        return pages

    def collect(self, category, include_cats=False, ns0=False):
        """Return the titles for *category* (``-`` reads standard input)."""
        if category == '-':
            return StdinLister().collect()
        data = self.catparser.parse(category)
        return self.collect_sub(data['url'], data['deep'], include_cats, ns0)


class MultiCategoryLister():
    """Collect several categories and combine them with a set operation."""

    def __init__(self, operation):
        self.lister = CategoryLister()
        self.operation = operation

    def collect(self, categories, include_cats=False, ns0=False):
        """Return the sorted result of applying the operation to *categories*."""
        sets = [set(self.lister.collect(category, include_cats, ns0))
                for category in categories]
        return sorted(self.operation(sets))


def _fold(op, sets):
    # reduce() raises TypeError on an empty sequence; treat "no sets" as empty.
    sets = list(sets)
    return reduce(op, sets) if sets else set()


def set_union(sets):
    """Return the items found in any of *sets*."""
    return _fold(operator.or_, sets)


def set_intersection(sets):
    """Return the items found in every one of *sets*."""
    return _fold(operator.and_, sets)


def set_difference(sets):
    """Return the items of the first set that are in none of the others."""
    return _fold(operator.sub, sets)


def set_pairwise_intersection(sets):
    """Return the items found in at least two of *sets*."""
    counts = Counter(item for members in sets for item in members)
    return {item for item, count in counts.items() if count > 1}


def set_symmetric_difference(sets):
    """Return the items found in an odd number of *sets*."""
    return _fold(operator.xor, sets)


# argparse attribute name -> operation, in the order the flags are declared.
_MODES = [('union', set_union),
          ('intersection', set_intersection),
          ('pairwise_intersection', set_pairwise_intersection),
          ('difference', set_difference),
          ('symmetric_difference', set_symmetric_difference)]


def _build_parser():
    # Declare the command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--union', help='all pages in any of the given categories',
        action='store_true')
    parser.add_argument(
        '--intersection', help='all pages in all of the given categories',
        action='store_true')
    parser.add_argument(
        '--pairwise-intersection',
        help='all pages in at least two of the given categories',
        action='store_true')
    parser.add_argument(
        '--difference',
        help='all pages in only the first of the given categories',
        action='store_true')
    parser.add_argument(
        '--symmetric-difference',
        help='all pages in only an odd number of the given categories',
        action='store_true')
    parser.add_argument(
        '--cats', help='include categories in list', action='store_true')
    parser.add_argument(
        '--ns0', help='only consider pages in namespace 0 (main namespace)',
        action='store_true')
    # const=0 keeps a bare "--limit" meaning "no limit" instead of None,
    # which would make the later "args.limit > 0" comparison fail.
    parser.add_argument(
        '--limit', help='limit final result to first N pages (0 for all)',
        nargs='?', const=0, default=0, type=int)
    parser.add_argument(
        '--output', help='File name, If not specified, goes to stdout',
        nargs='?', default=None)
    parser.add_argument(
        'category', nargs='+',
        help='In the format site|categoryname, such as "enwikt|English lemmas" (prefix with @ to use deep search); use - for stdin')
    return parser


def _select_operation(parser, args):
    # Pick the set operation chosen on the command line; exits on bad usage.
    chosen = [func for name, func in _MODES if getattr(args, name)]
    if len(chosen) == 1:
        return chosen[0]
    if len(args.category) > 1 or len(chosen) > 1:
        parser.print_help(sys.stderr)
        parser.error('must specify exactly one mode')
    # --union by default if only one category
    return set_union


def main(argv=None):
    """Run the command-line interface (*argv* defaults to ``sys.argv[1:]``)."""
    parser = _build_parser()
    args = parser.parse_args(argv)
    operation = _select_operation(parser, args)

    results = MultiCategoryLister(operation).collect(
        args.category, args.cats, args.ns0)
    if args.limit > 0:
        results = results[:args.limit]

    # The output file is opened only after collection has succeeded.
    output = sys.stdout if args.output is None else open(
        args.output, 'w', encoding='utf-8')
    try:
        for page in results:
            print(page, file=output)
    finally:
        if output is not sys.stdout:
            output.close()
    print(f'Total: {len(results)}', file=sys.stderr)


if __name__ == '__main__':
    main()