% Source: www.gusucode.com > bigdata toolbox MATLAB source code > bigdata/@tall/splitapply.m
function varargout = splitapply(fun,varargin)
%SPLITAPPLY Split data into groups and apply function
%   Supported syntaxes for tall arrays:
%   Y = SPLITAPPLY(FUN,X,G)
%   Y = SPLITAPPLY(FUN,X1,X2,...,G)
%   [Y1,Y2,...] = SPLITAPPLY(FUN,...)
%
%   See also FINDGROUPS

%   Copyright 2016 The MathWorks, Inc.

narginchk(3, inf);

fcnName = upper(mfilename);

% FUN must be an ordinary (non-tall) function handle.
checkNotTall(fcnName, 0, fun);
if ~isa(fun, 'function_handle')
    error(message('MATLAB:splitapply:InvalidFunction'));
end

% Every data input and the group numbers must themselves be tall.
argIsTall = cellfun(@istall, varargin);
if ~all(argIsTall)
    error(message('MATLAB:bigdata:array:AllArgsTall', fcnName));
end

% The group numbers are always the last argument; everything before them
% is data.
dataArgs = varargin(1:end - 1);
groupNums = tall.validateType(varargin{end}, fcnName, {'numeric'}, nargin);

% Expand a scalar group number to the full height of the data, because
% bykey and filter do not support singleton expansion. The group numbers
% are also converted to categorical so that grouped operations can take
% the set of known groups from the categories instead of from the unique
% values of each chunk.
groupNums = slicefun(@iValidateInputsAndExtractGnumAsCategorical, dataArgs{:}, groupNums);
groupNums.Adaptor = setSmallSizes(matlab.bigdata.internal.adaptors.CategoricalAdaptor(), 1);

% Support the syntax splitapply(fun, table(..), gnum) by passing each
% table variable as a separate input.
dataArgs = iFlattenTableInputs(dataArgs);

numOut = max(1, nargout);
[varargout{1:numOut}] = iSplitApply(fun, groupNums, dataArgs{:});

% The main implementation after input argument parsing.
function varargout = iSplitApply(fun, gnum, varargin)
% Apply FUN to grouped views of the tall data inputs.
%   GNUM is the tall categorical of group numbers produced by splitapply.
%   Each data input is wrapped as a grouped tall array bound to a
%   GroupedPartitionedArraySession created from FUN. FUN is then called once
%   on the wrapped inputs, and each of its outputs is converted back to an
%   ordinary tall array by iParseOutput. Any error is passed through
%   parseExecutionError and rethrown with the submission stack.

% Both objects are held in locals so they live until this function
% returns; onCleanup closes the session when sessionCleanup is destroyed.
markerFrame = matlab.bigdata.internal.InternalStackFrame(); %#ok<NASGU>
session = matlab.bigdata.internal.lazyeval.GroupedPartitionedArraySession(fun);
sessionCleanup = onCleanup(@session.close); %#ok<NASGU>
try
    [varargin{:}, gnum] = validateSameTallSize(varargin{:}, gnum);
    uniqueGnum = unique(gnum);
    [varargin{:}] = iWrapTallAsGroupedTall(gnum, session, varargin{:});
    [varargout{1:nargout}] = fun(varargin{:});
    for ii = 1:numel(varargout)
        varargout{ii} = iParseOutput(varargout{ii}, uniqueGnum, session);
    end
catch err
    err = matlab.bigdata.internal.lazyeval.parseExecutionError(err);
    matlab.bigdata.BigDataException.hThrowAsCallerWithSubmissionStack(err);
end

function data = iParseOutput(data, uniqueGnum, session)
% Parse the output generated by the function handle.
%   Handles outputs that are grouped tall arrays, cell arrays (possibly
%   containing tall arrays), and other non-tall values. Non-grouped outputs
%   are treated as one value per group and replicated for each entry of
%   UNIQUEGNUM by iCreateOutputFromScalar.
if iscell(data)
    data = iParseCellOutput(data);
end
if istall(data) && isa(hGetValueImpl(data), 'matlab.bigdata.internal.lazyeval.GroupedPartitionedArray')
    data = iUnWrapGroupedTall(data);
else
    funStr = matlab.bigdata.internal.broadcast(func2str(session.FunctionHandle));
    data = clientfun(@iCreateOutputFromScalar, uniqueGnum, data, funStr);
end

function data = iParseCellOutput(data)
% Parse a cell array output generated by the function handle.
%   Recurses into nested cells. If any element is tall, the cell is rebuilt
%   with clientfun so that the result is a tall value holding one cell
%   array of the original size.
if ~iscell(data)
    return;
end
data = cellfun(@iParseCellOutput, data, 'UniformOutput', false);
% Flatten before ANY: on a matrix-shaped cell, any() would return a row
% vector, and IF on that row is only true if every column contains a tall.
if any(cellfun(@istall, data(:)))
    % The cell array contains tall arrays, so bring the tall attribute up
    % one level so that the cell array is per group.
    data = clientfun(@(sz, varargin) reshape(varargin, sz), size(data), data{:});
end

% Wrap a collection of tall arrays each as a grouped tall array.
function varargout = iWrapTallAsGroupedTall(keys, session, varargin)
% Wrap each tall array in VARARGIN as a grouped tall array.
%   KEYS is the tall array of group keys, and SESSION is the
%   GroupedPartitionedArraySession shared by all wrapped arrays. Each output
%   keeps the adaptor of its input, with the tall size reset.
import matlab.bigdata.internal.lazyeval.GroupedPartitionedArray;
keys = hGetValueImpl(keys);
pv = cellfun(@hGetValueImpl, varargin, 'UniformOutput', false);
[varargout{1:nargout}] = GroupedPartitionedArray.create(keys, session, pv{:});
for ii = 1:numel(varargout)
    varargout{ii} = tall(varargout{ii});
    varargout{ii}.Adaptor = resetTallSize(varargin{ii}.Adaptor);
end

function out = iUnWrapGroupedTall(in)
% Unwrap a grouped tall array.
%   Converts the grouped values back to an ordinary tall array via
%   iCreateOutput (which validates one slice per group) and resets the tall
%   size on the adaptor.
gpv = hGetValueImpl(in);
funStr = matlab.bigdata.internal.broadcast(func2str(gpv.Session.FunctionHandle));
out = tall(clientfun(@iCreateOutput, gpv.Keys, gpv.Values, funStr));
out.Adaptor = resetTallSize(in.Adaptor);

function gnum = iValidateInputsAndExtractGnumAsCategorical(varargin)
% Validate one chunk of inputs and return the group numbers as categorical.
%   The last argument is the group numbers; all others are data. Every data
%   argument must have the same height (first dimension) as the first one.
%   The group numbers must be a numeric column of positive integers or NaN,
%   or a scalar, which is expanded to that height. NaN group numbers are
%   mapped to the hidden group 0 so that the end of the computation can
%   tell whether NaN values were present.
sz = size(varargin{1}, 1);
for ii = 2 : numel(varargin) - 1
    % Compare heights only: comparing the full size vector against a scalar
    % makes IF accept any input whose size contains sz (e.g. 3x5 vs 5).
    if size(varargin{ii}, 1) ~= sz
        error(message('MATLAB:bigdata:array:IncompatibleTallStrictSize'));
    end
end
gnum = varargin{end};
gnumIsValid = isnumeric(gnum) && iscolumn(gnum) && all(isnan(gnum) | mod(gnum, 1) == 0 & gnum > 0);
% We set NaN values to hidden group gnum 0. This is so at the end we know
% whether NaN values existed in the input.
gnum(isnan(gnum)) = 0;
if ~gnumIsValid
    % Depending on the partitioning and order of execution, gnum could be
    % invalid either by being a non-scalar row, or a matrix. Here we throw a
    % single error that covers both cases to ensure a consistent error is
    % returned.
    error(message('MATLAB:bigdata:array:SplitApplyUnsupportedGroupNums'));
end
if size(gnum, 1) == 1
    gnum = gnum .* ones(sz, 1);
elseif size(gnum, 1) ~= sz
    error(message('MATLAB:bigdata:array:IncompatibleTallStrictSize'));
end
gnum = categorical(gnum);

% Helper function that flattens all table inputs.
function flattenedInputs = iFlattenTableInputs(inputs)
% Replace every tall table in the cell array INPUTS by its variables.
%   Each table contributes one entry per variable, in variable order; every
%   non-table input is passed through unchanged. Returns a flat cell array.
flattenedInputs = cell(size(inputs));
for inputIndex = 1:numel(inputs)
    if strcmp(tall.getClass(inputs{inputIndex}), 'table')
        variableNames = inputs{inputIndex}.Adaptor.VariableNames;
        flattenedInputs{inputIndex} = cell(1, numel(variableNames));
        for varIndex = 1:numel(variableNames)
            % Explicit subsref is used rather than dot syntax.
            flattenedInputs{inputIndex}{varIndex} = subsref(inputs{inputIndex}, substruct('.', variableNames{varIndex}));
        end
    else
        flattenedInputs{inputIndex} = inputs(inputIndex);
    end
end
flattenedInputs = [flattenedInputs{:}];

function values = iCreateOutputFromScalar(gnum, values, funStr)
% Create the output from a single-slice value.
%   Used for user functions that do not return a grouped tall array. VALUES
%   must have exactly one row; it is replicated once per group number in
%   GNUM and then validated by iCreateOutput.
if size(values, 1) ~= 1
    idx = matlab.internal.tableUtils.ordinalString(1);
    error(message('MATLAB:bigdata:array:SplitApplyOutputNotUniform', funStr, idx));
end
values = repmat(values, numel(gnum), 1);
values = iCreateOutput(gnum, values, funStr);

function values = iCreateOutput(gnum, values, funStr)
% Create the output. This does no work except for reordering and error
% checking.
%   Sorts the slices of VALUES by group number (GNUM, categorical) and drops
%   the slices of the hidden NaN group 0. Then checks that each group
%   1..max(gnum) contributed exactly one slice.
%   Errors:
%     MATLAB:splitapply:InvalidGroupNums    - no groups and no NaN group
%     MATLAB:splitapply:MissingGroupNums    - a group in 1..max is absent
%     MATLAB:bigdata:array:SplitApplyOutputNotUniform - a group has more
%                                             than one output slice
gnum = double(string(gnum));
[gnum, idx] = sort(gnum);
values = matlab.bigdata.internal.util.indexSlices(values, idx);

isNanGroup = (gnum == 0);
if any(isNanGroup)
    gnum(isNanGroup) = [];
    values = matlab.bigdata.internal.util.indexSlices(values, ~isNanGroup);
end

if isempty(gnum)
    if any(isNanGroup)
        % Only NaN group numbers were present.
        return;
    else
        error(message('MATLAB:splitapply:InvalidGroupNums'));
    end
end

expectedGnum = (1:max(gnum))';
if ~isequal(gnum, expectedGnum)
    if ~isempty(setdiff(expectedGnum, gnum))
        error(message('MATLAB:splitapply:MissingGroupNums'));
    else
        % Every group is present, so at least one group number repeats,
        % i.e. that group produced more than one slice. gnum is sorted, so
        % the first zero difference marks the first repeated group. (Looking
        % only at gnum(1:max(gnum)) misses a repeat of the last group.)
        firstRepeat = find(diff(gnum) == 0, 1, 'first');
        idx = matlab.internal.tableUtils.ordinalString(gnum(firstRepeat));
        error(message('MATLAB:bigdata:array:SplitApplyOutputNotUniform', funStr, idx));
    end
end