Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions luigi/contrib/mrjob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Some awesome license

# import mrjob
from __future__ import absolute_import

from cStringIO import StringIO
from luigi.task import flatten
import logging
import luigi
import luigi.s3
from datetime import datetime


logger = logging.getLogger('luigi-interface')


class MRJobTask(luigi.Task):
job_class = None
job_args = luigi.Parameter(default=None)
job_options = {}
job_results = None
s3_output_path = None
misc_options = {}

def run_inline(self, job):
print 'gothereherhhe', self
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sambodanis i am still trying to unravel the meanings in the depth of this sentence ;-)

stdin = StringIO()
inlines = self.input()
inlines = flatten(inlines)
inlines = ('\n'.join(inlines)).replace('\n\n', '\n')
stdin.write(inlines)
stdin.seek(0)
job.sandbox(stdin=stdin)
results = []
with job.make_runner() as runner:
runner.run()
for line in runner.stream_output():
if self.misc_options.get('parse_output', False):
key, value = job.parse_output_line(line)
results.append((key, value))
else:
results.append(line)
return results

def run_emr(self, job):
results = []
with job.make_runner() as runner:
runner.run()
for line in runner.stream_output():
key, value = job.parse_output_line(line)
results.append((key, value))

def run(self):
self.job_args = list(self.job_args or []) or ['-r', 'inline', '--no-conf', '--strict-protocols', '-']

logger.info('Running mrjob task with arguments: %s' % ', '.join(self.job_args))

job = self.job_class(self.job_args)
self.job_options = vars(vars(job)['options'])

runner = self.job_options['runner']

if runner == 'inline':
self.job_results = self.run_inline(job)

elif runner == 'emr':

for arg in self.job_args:
if '--output-dir' in arg:
self.s3_output_path = arg.split('=', 1)[-1]

if not self.s3_output_path:
self.s3_output_path = str(self.task_id) + '_' + datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
logger.info('You did not specify an output s3 path. It will be automatically assigned as: %s' %
self.s3_output_path)

self.run_emr(job)
else:
raise NotImplementedError
logger.info('Finished running mrjob')

def complete(self):
return self.job_results is not None or self.s3_output_path is not None

def output(self):
runner = self.job_options['runner']
if runner == 'inline':
return self.job_results
elif runner == 'emr':
return luigi.s3.S3Target(self.s3_output_path)

235 changes: 235 additions & 0 deletions test/contrib/mrjob_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Super exciting license
from cStringIO import StringIO

from mrjob.job import MRJob
from mrjob.protocol import RawValueProtocol, RawProtocol
import mock
from mock import patch

from luigi.s3 import S3Target, S3Client
# from luigi.contrib import mrjob
import luigi.contrib
from luigi.contrib.mrjob import MRJobTask
from luigi.task import flatten
import luigi
import luigi.interface
import luigi.worker
import luigi.scheduler
import unittest
import tempfile
import os

# moto does not yet work with
# python 2.6. Until it does,
# disable these tests in python2.6
try:
from moto import mock_s3
except ImportError:
# https://github.com/spulec/moto/issues/29
print('Skipping %s because moto does not install properly before '
'python2.7' % __file__)
from luigi.mock import skip
mock_s3 = skip

AWS_ACCESS_KEY = "XXXXXXXXXXXXXXXXXXXX"
AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"


class MRWordFrequencyCount(MRJob):

def mapper(self, _, line):
yield "chars", len(line)
yield "words", len(line.split())
yield "lines", 1

def reducer(self, key, values):
yield key, sum(values)


class MRAddIncrementNumbers(MRJob):

# INPUT_PROTOCOL = RawValueProtocol
OUTPUT_PROTOCOL = RawValueProtocol

def mapper(self, _, number):
yield None, int(number) + 5

def reducer(self, _, values):

yield None, sum(map(int, values))


class DebugTarget(luigi.Target):

def __init__(self, data):
self.data = data

def __call__(self, *args, **kwargs):
return self.data


class DataDump(luigi.ExternalTask):

def run(self):
pass

def complete(self):
return True

def output(self):
return None


class WordCountTask(MRJobTask):
job_class = MRWordFrequencyCount

def requires(self):
return DataDump()


class AddIncrementTask1a(MRJobTask):
job_class = MRAddIncrementNumbers
# job_args = luigi.Parameter()

def requires(self):
return DataDump()


class AddIncrementTask1b(MRJobTask):
job_class = MRAddIncrementNumbers
# job_args = luigi.Parameter()

def requires(self):
return DataDump()


class AddIncrementTask2(MRJobTask):
job_class = MRAddIncrementNumbers
misc_options = {'parse_output': True}

def requires(self):
return AddIncrementTask1a(), AddIncrementTask1b()


class AddIncrementTask3(MRJobTask):
job_class = MRAddIncrementNumbers
misc_options = {'parse_output': True}
s3_bucket = ''

def requires(self):
return AddIncrementTask1a(job_args=['-r', 'emr', '--no-conf', '--strict-protocols',
'--output-dir=s3://mybucket/wc_out/']),\
AddIncrementTask1b(job_args=['-r', 'emr', '--no-conf', '--strict-protocols',
'--output-dir=s3://mybucket/wc_out/'])




class MRJobTaskTest(unittest.TestCase):

def setUp(self):
self.default_args = ['-r', 'inline', '--no-conf', '--strict-protocols', '-']
self.m_data_output = patch.object(DataDump, 'output', autospec=True).start()
self.m_run_emr = patch.object(MRJobTask, 'run_emr', autospec=True).start()
self.m_run_emr.side_effect = lambda self_, _: self.fake_emr(self_)

self.client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY)

f = tempfile.NamedTemporaryFile(mode='wb', delete=False)
self.tempFileContents = (
'\n'.join(["hello", "hello", "hello"])
)
self.tempFilePath = f.name
f.write(self.tempFileContents)
f.close()

def tearDown(self):
os.remove(self.tempFilePath)

def run_task_with_input(self, task, input_data):
self.m_data_output.return_value = input_data
task.run()

def test_run_task(self):
task = WordCountTask()
task.misc_options = {'parse_output': True}
self.run_task_with_input(task, ["hello", "hello", "hello"])

result = task.output()
comparison = [('chars', 15), ('lines', 3), ('words', 3)]
self.assertEqual(result, comparison)

def fake_emr(self, task):
args = task.job_args
args[args.index('emr')] = 'inline'
job = task.job_class(args)
stdin = StringIO()
if isinstance(task.input(), list):
inlines = []
for t in task.input():
inlines.append(t.open().read().split())
inlines = flatten(inlines)
else:
inlines = task.input().open().read().split()

inlines = '\n'.join(inlines)
stdin.write(inlines)
stdin.seek(0)
job.sandbox(stdin=stdin)
results = ''
task.job_options['runner'] = 'inline'
with job.make_runner() as runner:
runner.run()
for line in runner.stream_output():
results += line
self.client.put_string(results, task.s3_output_path)
task.job_options['runner'] = 'emr'
return results

@mock_s3
def test_fake_emr(self):
bucket = 'mybucket'
output_path = 's3://mybucket/wc_out/'
input_path = 's3://mybucket/wc_in/'
self.client.s3.create_bucket(bucket)
self.client.put(self.tempFilePath, input_path)

args = ['-r', 'emr', '--no-conf', '--strict-protocols', '--output-dir=' + output_path]
task = WordCountTask()
task.job_args = args
self.run_task_with_input(task, S3Target(input_path, client=self.client))

read_file = task.output().open()
file_str = read_file.read()
self.assertEqual(file_str, '"chars"\t15\n"lines"\t3\n"words"\t3\n')

def test_multiple_tasks(self):
self.m_data_output.return_value = ['1', '2', '3', '4']
task = AddIncrementTask2()
luigi.interface.setup_interface_logging()
sch = luigi.scheduler.CentralPlannerScheduler()
w = luigi.worker.Worker(scheduler=sch)
w.add(task)
w.run()
self.assertEqual(task.output(), [(None, '70\n')])

@mock_s3
def test_multiple_fake_emr(self):
bucket = 'mybucket'
output_path = 's3://mybucket/wc_out/'
input_path = 's3://mybucket/wc_in/'
self.client.s3.create_bucket(bucket)
self.client.put_string('\n'.join(['1', '2', '3', '4']), input_path)

self.m_data_output.return_value = S3Target(input_path, client=self.client)
task = AddIncrementTask3(job_args=['-r', 'emr', '--no-conf',
'--strict-protocols', '--output-dir=' + output_path])

luigi.interface.setup_interface_logging()
sch = luigi.scheduler.CentralPlannerScheduler()
w = luigi.worker.Worker(scheduler=sch)
w.add(task)
w.run()
read_file = task.output().open()
file_str = read_file.read()
self.assertEqual(file_str, '70\n')