The visitor design pattern is a powerful programming pattern that I had not fully appreciated. I knew the pattern and had seen code examples using it, but I hadn’t realized how powerful and elegant it is. I confess that I overlooked it, though the descriptions on sites, in blogs, and in books usually left me confused. Nonetheless, I should have worked through its practical aspects instead of just reading about the concepts… lesson learned!

Traversing the Object Hierarchy

It’s right there on Wikipedia:

Clients traverse the object structure and call a dispatching operation accept(visitor) on an element — that “dispatches” (delegates) the request to the “accepted visitor object”. The visitor object then performs the operation on the element (“visits the element”).

But I only really read it after reading the excellent Crafting Interpreters book.

Crafting Interpreters uses the visitor pattern to work with the abstract syntax tree (AST), where the object hierarchy (created during the parsing phase) is the tree. Clients (be it an interpreter, a compiler, or even something that prints the AST) can traverse the tree hierarchy by “visiting” the objects. Cool.

Coding

This code is obviously based on the book mentioned above, but I only wrote the parts needed to highlight the visitor pattern.

 class expression_visitor;  // only used by reference here, so a declaration is enough

 /// Base class of every node in the expression tree.
 ///
 /// A node does not evaluate itself; it only hands itself to a visitor
 /// through accept().
 class expression_t
 {
 public:
     /// Passes this node to the visitor and returns the visitor's result.
     virtual int accept(expression_visitor &v) = 0;

     /// Virtual so that a node can be destroyed through a base-class pointer.
     virtual ~expression_t() = default;
 };

 /// Owning handle to a node of the expression tree.
 using expression = std::unique_ptr<expression_t>;
 8    
 9    class type_number : public expression_t
10    {
11        int value_;
12    
13    public:
14        type_number(int v) :
15            value_(v)
16        {
17        }
18    
19        type_number(string v) :
20            value_(stoi(v))
21        {
22        }
23    
24        int get() const
25        {
26            return value_;
27        }
28    
29        int accept(expression_visitor &v)
30        {
31            return v.visit_number(*this);
32        }
33    };
34    
35    class binary_expression : public expression_t
36    {
37        expression left_;
38        expression right_;
39        char oper_;
40    
41        public:
42        binary_expression(expression l,
43                          expression r,
44                          char o) :
45            left_(move(l)),
46            right_(move(r)),
47            oper_(o)
48        {
49        }
50    
51        expression_t *left()
52        {
53            return left_.get();
54        }
55    
56        expression_t *right()
57        {
58            return right_.get();
59        }
60    
61        char operation() const
62        {
63            return oper_;
64        }
65    
66        int accept(expression_visitor &v)
67        {
68            return v.visit_binary(*this);
69        }
70    };

Nothing special here: an abstract class with a pure virtual method called accept() and two classes implementing it. One represents numbers; the other represents binary expressions like 3 * 5. accept() simply calls a method on the visitor, passing the instance itself as the argument.

 // The visitor only mentions the node types by reference, so declarations
 // are enough here.
 class type_number;
 class binary_expression;

 /// Visitor interface: one visit method per concrete node type.
 class expression_visitor
 {
 public:
     /// Called by type_number::accept().
     virtual int visit_number(type_number &v) = 0;

     /// Called by binary_expression::accept().
     virtual int visit_binary(binary_expression &v) = 0;

     // Defined here rather than only declared: every derived visitor's
     // destructor calls this one, so a missing definition is a link error.
     virtual ~expression_visitor() = default;
 };
 8    
 9    class interpreter : public expression_visitor
10    {
11        int visit_number(type_number &n)
12        {
13            return n.get();
14        }
15    
16        int visit_binary(binary_expression &b)
17        {
18            int left = evaluate(b.left());
19            int right = evaluate(b.right());
20    
21            switch (b.operation()) {
22                case '+':
23                    return left + right;
24                    break;
25    
26                case '-':
27                    return left - right;
28                    break;
29    
30                case '*':
31                    return left * right;
32                    break;
33    
34                case '/':
35                    return left / right;
36                    break;
37            }
38    
39            throw exception();
40        }
41    
42        int evaluate(expression_t *e)
43        {
44            return e->accept(*this);
45        }
46    };

Now, the visitor. The interpreter implements the visitor, giving meaning to each object that it needs to visit. Note that visit_number() returns the value stored in type_number and visit_binary() evaluates both the left() and right() expressions, each of which can hold either a type_number or another binary_expression. In other words, visit_binary() is called recursively until it reaches a type_number. Isn’t it beautiful and elegant?

visitor diagram

Full code listing

Here is the full code listing:

  #include <cctype>
  #include <cstdlib>
  #include <exception>
  #include <iostream>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <utility>
  5    
  6    using namespace std;
  7    
  // Node types, declared ahead of their definitions so that the visitor
  // interface can mention them. type_string and unary_expression have no
  // definition in this listing.
  class type_number;
  class type_string;
  class binary_expression;
  class unary_expression;

  /// Visitor interface: one visit method per concrete node type.
  class expression_visitor
  {
  public:
      /// Called by type_number::accept().
      virtual int visit_number(type_number &v) = 0;

      /// Called by binary_expression::accept().
      virtual int visit_binary(binary_expression &v) = 0;

      virtual ~expression_visitor() = default;
  };
 20    
 21    class expression_t
 22    {
 23    public:
 24        virtual int accept(expression_visitor &v) = 0;
 25        virtual ~expression_t() {}
 26    };
 27    using expression = unique_ptr<expression_t>;
 28    
 29    class type_number : public expression_t
 30    {
 31        int value_;
 32    
 33    public:
 34        type_number(int v) :
 35            value_(v)
 36        {
 37        }
 38    
 39        type_number(string v) :
 40            value_(stoi(v))
 41        {
 42        }
 43    
 44        int get() const
 45        {
 46            return value_;
 47        }
 48    
 49        int accept(expression_visitor &v)
 50        {
 51            return v.visit_number(*this);
 52        }
 53    };
 54    
 55    class binary_expression : public expression_t
 56    {
 57        expression left_;
 58        expression right_;
 59        char oper_;
 60    
 61        public:
 62        binary_expression(expression l,
 63                          expression r,
 64                          char o) :
 65            left_(move(l)),
 66            right_(move(r)),
 67            oper_(o)
 68        {
 69        }
 70    
 71        expression_t *left()
 72        {
 73            return left_.get();
 74        }
 75    
 76        expression_t *right()
 77        {
 78            return right_.get();
 79        }
 80    
 81        char operation() const
 82        {
 83            return oper_;
 84        }
 85    
 86        int accept(expression_visitor &v)
 87        {
 88            return v.visit_binary(*this);
 89        }
 90    };
 91    
 92    class interpreter : public expression_visitor
 93    {
 94        int visit_number(type_number &n)
 95        {
 96            return n.get();
 97        }
 98    
 99        int visit_binary(binary_expression &b)
100        {
101            int left = evaluate(b.left());
102            int right = evaluate(b.right());
103    
104            switch (b.operation()) {
105                case '+':
106                    return left + right;
107                    break;
108    
109                case '-':
110                    return left - right;
111                    break;
112    
113                case '*':
114                    return left * right;
115                    break;
116    
117                case '/':
118                    return left / right;
119                    break;
120            }
121    
122            throw exception();
123        }
124    
125        int evaluate(expression_t *e)
126        {
127            return e->accept(*this);
128        }
129    
130        public:
131        void compute(expression x)
132        {
133            cout << evaluate(x.get()) << endl;
134        }
135    };
136    
137    class parser
138    {
139        private:
140        expression parse(string s)
141        {
142            stringstream tokens;
143            tokens << s;
144            return parse_add_sub(tokens);
145        }
146    
147        expression parse_add_sub(stringstream &tk)
148        {
149            expression left = parse_mult_div(tk);
150    
151            while (tk.peek() == '+' || tk.peek() == '-') {
152                char operation = tk.get();
153                expression right = parse_mult_div(tk);
154                left = make_unique<binary_expression>(move(left),
155                                                      move(right),
156                                                      operation);
157            }
158    
159            return left;
160        }
161    
162        expression parse_mult_div(stringstream &tk)
163        {
164            expression left = parse_number(tk);
165    
166            while (tk.peek() == '*' || tk.peek() == '/') {
167                char operation = tk.get();
168                expression right = parse_number(tk);
169                left = make_unique<binary_expression>(move(left),
170                                                      move(right),
171                                                      operation);
172            }
173    
174            return left;
175        }
176    
177        expression parse_number(stringstream &tk)
178        {
179            string sval;
180            tk >> sval;
181            int value = 0;
182    
183            try {
184                value = stoi(sval);
185            }
186            catch (invalid_argument &e) {
187                cerr << "expected a number, found " << sval << endl;
188                exit(1);
189            }
190            catch (out_of_range &e) {
191                cerr << "number " << sval << " overflows an integer storage" << endl;
192                exit(1);
193            }
194    
195            while (tk.peek() == ' ')
196                tk.get();
197    
198            return make_unique<type_number>(value);
199        }
200    
201        public:
202        expression parse_it(string s)
203        {
204            return parse(s);
205        }
206    };
207    
208    int main()
209    {
210        interpreter it;
211        parser p;
212    
213        while (true) {
214            string line;
215    
216            getline(cin, line);
217            if (line == "quit")
218                break;
219    
220            it.compute(p.parse_it(line));
221        }
222        return 0;
223    }
    g++ -std=c++14 -Wall -Wextra -g visitors.cpp -o visitors
    % ./visitors
    3 * 5 + 8 - 3
    20
    15 * 80 / 2 + 3 * 8
    624
    quit